cubecl_ir/
variable.rs

1use core::{fmt::Display, hash::Hash};
2
3use crate::{BarrierLevel, FloatKind, IntKind, StorageType, TypeHash};
4
5use super::{ElemType, Matrix, Type, UIntKind};
6use cubecl_common::{e2m1, e4m3, e5m2, ue8m0};
7use derive_more::From;
8use float_ord::FloatOrd;
9
/// A typed IR operand: a [`VariableKind`] saying what the variable is, paired with its [`Type`].
///
/// The type is `Copy`, so variables are passed around by value.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash)]
#[allow(missing_docs)]
pub struct Variable {
    /// What this variable refers to (global, local, shared, constant, builtin, ...).
    pub kind: VariableKind,
    /// The type of the value held by this variable.
    pub ty: Type,
}
17
18impl Variable {
19    pub fn new(kind: VariableKind, item: Type) -> Self {
20        Self { kind, ty: item }
21    }
22
23    pub fn builtin(builtin: Builtin, ty: StorageType) -> Self {
24        Self::new(VariableKind::Builtin(builtin), Type::new(ty))
25    }
26
27    pub fn constant(value: ConstantValue, ty: impl Into<Type>) -> Self {
28        let ty = ty.into();
29        let value = value.cast_to(ty);
30        Self::new(VariableKind::Constant(value), ty)
31    }
32
33    pub fn elem_type(&self) -> ElemType {
34        self.ty.elem_type()
35    }
36
37    pub fn storage_type(&self) -> StorageType {
38        self.ty.storage_type()
39    }
40}
41
/// Numeric identifier carried by most [`VariableKind`] variants (see [`Variable::index`]).
pub type Id = u32;
43
/// What a [`Variable`] refers to.
///
/// Most variants carry an [`Id`]. [`Variable::is_immutable`], [`Variable::is_array`] and
/// [`Variable::index`] classify the variants.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash)]
pub enum VariableKind {
    /// Global input array, printed as `input(id)`.
    GlobalInputArray(Id),
    /// Global output array, printed as `output(id)`.
    GlobalOutputArray(Id),
    /// Global scalar, printed as `scalar(id)`; treated as immutable.
    GlobalScalar(Id),
    /// Input tensor map, printed as `tensor_map(id)`; treated as immutable.
    TensorMapInput(Id),
    /// Output tensor map, also printed as `tensor_map(id)`.
    TensorMapOutput(Id),
    /// Local array of `length` elements, printed as `array(id)`.
    // NOTE(review): `unroll_factor` semantics are not visible in this file — confirm.
    LocalArray {
        id: Id,
        length: usize,
        unroll_factor: usize,
    },
    /// Mutable local, printed as `local(id)`.
    LocalMut {
        id: Id,
    },
    /// Immutable local binding, printed as `binding(id)`.
    LocalConst {
        id: Id,
    },
    /// A specific version of local `id`, printed as `local(id).vN`; treated as immutable.
    // NOTE(review): presumably an SSA version of a `LocalMut` — confirm.
    Versioned {
        id: Id,
        version: u16,
    },
    /// Constant value stored inline in the IR.
    Constant(ConstantValue),
    /// Immutable array of constants, printed as `const_array(id)`.
    ConstantArray {
        id: Id,
        length: usize,
        unroll_factor: usize,
    },
    /// Shared array, printed as `shared_array(id)`.
    // NOTE(review): units of `alignment` are not shown here (presumably bytes) — confirm.
    SharedArray {
        id: Id,
        length: usize,
        unroll_factor: usize,
        alignment: Option<usize>,
    },
    /// Shared (non-array) value, printed as `shared(id)`.
    Shared {
        id: Id,
    },
    /// Matrix described by [`Matrix`], printed as `matrix(id)`.
    Matrix {
        id: Id,
        mat: Matrix,
    },
    /// A [`Builtin`] value.
    Builtin(Builtin),
    /// Pipeline with `num_stages` stages, printed as `pipeline(id)`.
    Pipeline {
        id: Id,
        num_stages: u8,
    },
    /// Barrier token at the given [`BarrierLevel`], printed as `barrier_token(id)`.
    BarrierToken {
        id: Id,
        level: BarrierLevel,
    },
}
96
/// Builtin values, referenced through [`VariableKind::Builtin`] (see [`Variable::builtin`]).
///
/// The enum is `repr(u8)` and derives `Ord`. Declaration order is therefore both the discriminant
/// and the sort order, so do not reorder existing variants.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TypeHash, PartialOrd, Ord)]
#[repr(u8)]
pub enum Builtin {
    UnitPos,
    UnitPosX,
    UnitPosY,
    UnitPosZ,
    CubePosCluster,
    CubePosClusterX,
    CubePosClusterY,
    CubePosClusterZ,
    CubePos,
    CubePosX,
    CubePosY,
    CubePosZ,
    CubeDim,
    CubeDimX,
    CubeDimY,
    CubeDimZ,
    CubeClusterDim,
    CubeClusterDimX,
    CubeClusterDimY,
    CubeClusterDimZ,
    CubeCount,
    CubeCountX,
    CubeCountY,
    CubeCountZ,
    PlaneDim,
    UnitPosPlane,
    AbsolutePos,
    AbsolutePosX,
    AbsolutePosY,
    AbsolutePosZ,
}
132
133impl Variable {
134    /// Whether a variable is always immutable. Used for optimizations to determine whether it's
135    /// safe to inline/merge
136    pub fn is_immutable(&self) -> bool {
137        match self.kind {
138            VariableKind::GlobalOutputArray { .. } => false,
139            VariableKind::TensorMapInput(_) => true,
140            VariableKind::TensorMapOutput(_) => false,
141            VariableKind::LocalMut { .. } => false,
142            VariableKind::SharedArray { .. } => false,
143            VariableKind::Shared { .. } => false,
144            VariableKind::Matrix { .. } => false,
145            VariableKind::LocalArray { .. } => false,
146            VariableKind::GlobalInputArray { .. } => false,
147            VariableKind::GlobalScalar { .. } => true,
148            VariableKind::Versioned { .. } => true,
149            VariableKind::LocalConst { .. } => true,
150            VariableKind::Constant(_) => true,
151            VariableKind::ConstantArray { .. } => true,
152            VariableKind::Builtin(_) => true,
153            VariableKind::Pipeline { .. } => false,
154            VariableKind::BarrierToken { .. } => false,
155        }
156    }
157
158    /// Is this an array type that yields [`Item`]s when indexed, or a scalar/vector that yields
159    /// [`Elem`]s when indexed?
160    pub fn is_array(&self) -> bool {
161        matches!(
162            self.kind,
163            VariableKind::GlobalInputArray { .. }
164                | VariableKind::GlobalOutputArray { .. }
165                | VariableKind::ConstantArray { .. }
166                | VariableKind::SharedArray { .. }
167                | VariableKind::LocalArray { .. }
168                | VariableKind::Matrix { .. }
169        )
170    }
171
172    pub fn has_length(&self) -> bool {
173        matches!(
174            self.kind,
175            VariableKind::GlobalInputArray { .. } | VariableKind::GlobalOutputArray { .. }
176        )
177    }
178
179    pub fn has_buffer_length(&self) -> bool {
180        matches!(
181            self.kind,
182            VariableKind::GlobalInputArray { .. } | VariableKind::GlobalOutputArray { .. }
183        )
184    }
185
186    /// Determines if the value is a constant with the specified value (converted if necessary)
187    pub fn is_constant(&self, value: i64) -> bool {
188        match self.kind {
189            VariableKind::Constant(ConstantValue::Int(val)) => val == value,
190            VariableKind::Constant(ConstantValue::UInt(val)) => val as i64 == value,
191            VariableKind::Constant(ConstantValue::Float(val)) => val == value as f64,
192            _ => false,
193        }
194    }
195
196    /// Determines if the value is a boolean constant with the `true` value
197    pub fn is_true(&self) -> bool {
198        match self.kind {
199            VariableKind::Constant(ConstantValue::Bool(val)) => val,
200            _ => false,
201        }
202    }
203
204    /// Determines if the value is a boolean constant with the `false` value
205    pub fn is_false(&self) -> bool {
206        match self.kind {
207            VariableKind::Constant(ConstantValue::Bool(val)) => !val,
208            _ => false,
209        }
210    }
211}
212
213/// The scalars are stored with the highest precision possible, but they might get reduced during
214/// compilation. For constant propagation, casts are always executed before converting back to the
215/// larger type to ensure deterministic output.
216#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
217#[derive(Debug, Clone, Copy, TypeHash, PartialEq, PartialOrd, From)]
218#[allow(missing_docs, clippy::derive_ord_xor_partial_ord)]
219pub enum ConstantValue {
220    Int(i64),
221    Float(f64),
222    UInt(u64),
223    Bool(bool),
224}
225
226impl Ord for ConstantValue {
227    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
228        // Override float-float comparison with `FloatOrd` since `f64` isn't `Ord`. All other
229        // comparisons are safe to unwrap since they're either `Ord` or only compare discriminants.
230        match (self, other) {
231            (ConstantValue::Float(this), ConstantValue::Float(other)) => {
232                FloatOrd(*this).cmp(&FloatOrd(*other))
233            }
234            _ => self.partial_cmp(other).unwrap(),
235        }
236    }
237}
238
239impl Eq for ConstantValue {}
240impl Hash for ConstantValue {
241    fn hash<H: core::hash::Hasher>(&self, ra_expand_state: &mut H) {
242        core::mem::discriminant(self).hash(ra_expand_state);
243        match self {
244            ConstantValue::Int(f0) => {
245                f0.hash(ra_expand_state);
246            }
247            ConstantValue::Float(f0) => {
248                FloatOrd(*f0).hash(ra_expand_state);
249            }
250            ConstantValue::UInt(f0) => {
251                f0.hash(ra_expand_state);
252            }
253            ConstantValue::Bool(f0) => {
254                f0.hash(ra_expand_state);
255            }
256        }
257    }
258}
259
260impl ConstantValue {
261    /// Returns the value of the constant as a usize.
262    ///
263    /// It will return [None] if the constant type is a float or a bool.
264    pub fn try_as_usize(&self) -> Option<usize> {
265        match self {
266            ConstantValue::UInt(val) => Some(*val as usize),
267            ConstantValue::Int(val) => Some(*val as usize),
268            ConstantValue::Float(_) => None,
269            ConstantValue::Bool(_) => None,
270        }
271    }
272
273    /// Returns the value of the constant as a usize.
274    pub fn as_usize(&self) -> usize {
275        match self {
276            ConstantValue::UInt(val) => *val as usize,
277            ConstantValue::Int(val) => *val as usize,
278            ConstantValue::Float(val) => *val as usize,
279            ConstantValue::Bool(val) => *val as usize,
280        }
281    }
282
283    /// Returns the value of the scalar as a u32.
284    ///
285    /// It will return [None] if the scalar type is a float or a bool.
286    pub fn try_as_u32(&self) -> Option<u32> {
287        self.try_as_u64().map(|it| it as u32)
288    }
289
290    /// Returns the value of the scalar as a u32.
291    ///
292    /// It will panic if the scalar type is a float or a bool.
293    pub fn as_u32(&self) -> u32 {
294        self.as_u64() as u32
295    }
296
297    /// Returns the value of the scalar as a u64.
298    ///
299    /// It will return [None] if the scalar type is a float or a bool.
300    pub fn try_as_u64(&self) -> Option<u64> {
301        match self {
302            ConstantValue::UInt(val) => Some(*val),
303            ConstantValue::Int(val) => Some(*val as u64),
304            ConstantValue::Float(_) => None,
305            ConstantValue::Bool(_) => None,
306        }
307    }
308
309    /// Returns the value of the scalar as a u64.
310    pub fn as_u64(&self) -> u64 {
311        match self {
312            ConstantValue::UInt(val) => *val,
313            ConstantValue::Int(val) => *val as u64,
314            ConstantValue::Float(val) => *val as u64,
315            ConstantValue::Bool(val) => *val as u64,
316        }
317    }
318
319    /// Returns the value of the scalar as a i64.
320    ///
321    /// It will return [None] if the scalar type is a float or a bool.
322    pub fn try_as_i64(&self) -> Option<i64> {
323        match self {
324            ConstantValue::UInt(val) => Some(*val as i64),
325            ConstantValue::Int(val) => Some(*val),
326            ConstantValue::Float(_) => None,
327            ConstantValue::Bool(_) => None,
328        }
329    }
330
331    /// Returns the value of the scalar as a i128.
332    pub fn as_i128(&self) -> i128 {
333        match self {
334            ConstantValue::UInt(val) => *val as i128,
335            ConstantValue::Int(val) => *val as i128,
336            ConstantValue::Float(val) => *val as i128,
337            ConstantValue::Bool(val) => *val as i128,
338        }
339    }
340
341    /// Returns the value of the scalar as a i64.
342    pub fn as_i64(&self) -> i64 {
343        match self {
344            ConstantValue::UInt(val) => *val as i64,
345            ConstantValue::Int(val) => *val,
346            ConstantValue::Float(val) => *val as i64,
347            ConstantValue::Bool(val) => *val as i64,
348        }
349    }
350
351    /// Returns the value of the scalar as a f64.
352    ///
353    /// It will return [None] if the scalar type is an int or a bool.
354    pub fn try_as_f64(&self) -> Option<f64> {
355        match self {
356            ConstantValue::Float(val) => Some(*val),
357            _ => None,
358        }
359    }
360
361    /// Returns the value of the scalar as a f64.
362    pub fn as_f64(&self) -> f64 {
363        match self {
364            ConstantValue::UInt(val) => *val as f64,
365            ConstantValue::Int(val) => *val as f64,
366            ConstantValue::Float(val) => *val,
367            ConstantValue::Bool(val) => *val as u8 as f64,
368        }
369    }
370
371    /// Returns the value of the variable as a bool if it actually is a bool.
372    pub fn try_as_bool(&self) -> Option<bool> {
373        match self {
374            ConstantValue::Bool(val) => Some(*val),
375            _ => None,
376        }
377    }
378
379    /// Returns the value of the variable as a bool.
380    ///
381    /// It will panic if the scalar isn't a bool.
382    pub fn as_bool(&self) -> bool {
383        match self {
384            ConstantValue::UInt(val) => *val != 0,
385            ConstantValue::Int(val) => *val != 0,
386            ConstantValue::Float(val) => *val != 0.,
387            ConstantValue::Bool(val) => *val,
388        }
389    }
390
391    pub fn is_zero(&self) -> bool {
392        match self {
393            ConstantValue::Int(val) => *val == 0,
394            ConstantValue::Float(val) => *val == 0.0,
395            ConstantValue::UInt(val) => *val == 0,
396            ConstantValue::Bool(val) => !*val,
397        }
398    }
399
400    pub fn is_one(&self) -> bool {
401        match self {
402            ConstantValue::Int(val) => *val == 1,
403            ConstantValue::Float(val) => *val == 1.0,
404            ConstantValue::UInt(val) => *val == 1,
405            ConstantValue::Bool(val) => *val,
406        }
407    }
408
409    pub fn cast_to(&self, other: impl Into<Type>) -> ConstantValue {
410        match other.into().storage_type() {
411            StorageType::Scalar(elem_type) => match elem_type {
412                ElemType::Float(kind) => match kind {
413                    FloatKind::E2M1 => e2m1::from_f64(self.as_f64()).to_f64(),
414                    FloatKind::E2M3 | FloatKind::E3M2 => {
415                        unimplemented!("FP6 constants not yet supported")
416                    }
417                    FloatKind::E4M3 => e4m3::from_f64(self.as_f64()).to_f64(),
418                    FloatKind::E5M2 => e5m2::from_f64(self.as_f64()).to_f64(),
419                    FloatKind::UE8M0 => ue8m0::from_f64(self.as_f64()).to_f64(),
420                    FloatKind::F16 => half::f16::from_f64(self.as_f64()).to_f64(),
421                    FloatKind::BF16 => half::bf16::from_f64(self.as_f64()).to_f64(),
422                    FloatKind::Flex32 | FloatKind::TF32 | FloatKind::F32 => {
423                        self.as_f64() as f32 as f64
424                    }
425                    FloatKind::F64 => self.as_f64(),
426                }
427                .into(),
428                ElemType::Int(kind) => match kind {
429                    IntKind::I8 => self.as_i64() as i8 as i64,
430                    IntKind::I16 => self.as_i64() as i16 as i64,
431                    IntKind::I32 => self.as_i64() as i32 as i64,
432                    IntKind::I64 => self.as_i64(),
433                }
434                .into(),
435                ElemType::UInt(kind) => match kind {
436                    UIntKind::U8 => self.as_u64() as u8 as u64,
437                    UIntKind::U16 => self.as_u64() as u16 as u64,
438                    UIntKind::U32 => self.as_u64() as u32 as u64,
439                    UIntKind::U64 => self.as_u64(),
440                }
441                .into(),
442                ElemType::Bool => self.as_bool().into(),
443            },
444            StorageType::Packed(ElemType::Float(FloatKind::E2M1), 2) => {
445                e2m1::from_f64(self.as_f64()).to_f64().into()
446            }
447            StorageType::Packed(..) => unimplemented!("Unsupported packed type"),
448            StorageType::Atomic(_) => unimplemented!("Atomic constants aren't supported"),
449            StorageType::Opaque(_) => unimplemented!("Opaque constants aren't supported"),
450        }
451    }
452}
453
454impl Display for ConstantValue {
455    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
456        match self {
457            ConstantValue::Int(val) => write!(f, "{val}"),
458            ConstantValue::Float(val) => write!(f, "{val:?}"),
459            ConstantValue::UInt(val) => write!(f, "{val}"),
460            ConstantValue::Bool(val) => write!(f, "{val}"),
461        }
462    }
463}
464
465impl Variable {
466    pub fn line_size(&self) -> usize {
467        self.ty.line_size()
468    }
469
470    pub fn index(&self) -> Option<Id> {
471        match self.kind {
472            VariableKind::GlobalInputArray(id)
473            | VariableKind::GlobalOutputArray(id)
474            | VariableKind::TensorMapInput(id)
475            | VariableKind::TensorMapOutput(id)
476            | VariableKind::GlobalScalar(id)
477            | VariableKind::LocalMut { id, .. }
478            | VariableKind::Versioned { id, .. }
479            | VariableKind::LocalConst { id, .. }
480            | VariableKind::ConstantArray { id, .. }
481            | VariableKind::SharedArray { id, .. }
482            | VariableKind::Shared { id, .. }
483            | VariableKind::LocalArray { id, .. }
484            | VariableKind::Matrix { id, .. } => Some(id),
485            _ => None,
486        }
487    }
488
489    pub fn as_const(&self) -> Option<ConstantValue> {
490        match self.kind {
491            VariableKind::Constant(constant) => Some(constant),
492            _ => None,
493        }
494    }
495}
496
497impl Display for Variable {
498    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
499        match self.kind {
500            VariableKind::GlobalInputArray(id) => write!(f, "input({id})"),
501            VariableKind::GlobalOutputArray(id) => write!(f, "output({id})"),
502            VariableKind::GlobalScalar(id) => write!(f, "scalar({id})"),
503            VariableKind::TensorMapInput(id) => write!(f, "tensor_map({id})"),
504            VariableKind::TensorMapOutput(id) => write!(f, "tensor_map({id})"),
505            VariableKind::Constant(constant) => write!(f, "{}({constant})", self.ty),
506            VariableKind::LocalMut { id } => write!(f, "local({id})"),
507            VariableKind::Versioned { id, version } => {
508                write!(f, "local({id}).v{version}")
509            }
510            VariableKind::LocalConst { id } => write!(f, "binding({id})"),
511            VariableKind::ConstantArray { id, .. } => write!(f, "const_array({id})"),
512            VariableKind::SharedArray { id, .. } => write!(f, "shared_array({id})"),
513            VariableKind::Shared { id } => write!(f, "shared({id})"),
514            VariableKind::LocalArray { id, .. } => write!(f, "array({id})"),
515            VariableKind::Matrix { id, .. } => write!(f, "matrix({id})"),
516            VariableKind::Builtin(builtin) => write!(f, "{builtin:?}"),
517            VariableKind::Pipeline { id, .. } => write!(f, "pipeline({id})"),
518            VariableKind::BarrierToken { id, .. } => write!(f, "barrier_token({id})"),
519        }
520    }
521}
522
523// Useful with the cube_inline macro.
524impl From<&Variable> for Variable {
525    fn from(value: &Variable) -> Self {
526        *value
527    }
528}