cubecl_cpp/shared/variable.rs

1use cubecl_core::ir::{self as gpu, BarrierLevel, ConstantScalarValue, Id};
2use std::fmt::Display;
3
4use super::{COUNTER_TMP_VAR, Dialect, Elem, Fragment, FragmentIdent, Item};
5
6pub trait Component<D: Dialect>: Display + FmtLeft {
7    fn item(&self) -> Item<D>;
8    fn is_const(&self) -> bool;
9    fn index(&self, index: usize) -> IndexedVariable<D>;
10    fn elem(&self) -> Elem<D> {
11        *self.item().elem()
12    }
13}
14
/// Formatting of a value when it is the target (left-hand side) of a write.
///
/// Unlike plain `Display`, implementors may prepend a declaration (type, `const`) when the
/// target must be declared at the point of assignment.
pub trait FmtLeft: Display {
    /// Renders the value as a write target.
    fn fmt_left(&self) -> String;
}
18
19#[derive(new)]
20pub struct OptimizedArgs<const N: usize, D: Dialect> {
21    pub args: [Variable<D>; N],
22    pub optimization_factor: Option<usize>,
23}
24
25#[derive(Debug, Clone, Copy, PartialEq)]
26pub enum Variable<D: Dialect> {
27    AbsolutePos,
28    AbsolutePosBaseName, // base name for XYZ
29    AbsolutePosX,
30    AbsolutePosY,
31    AbsolutePosZ,
32    UnitPos,
33    UnitPosBaseName, // base name for XYZ
34    UnitPosX,
35    UnitPosY,
36    UnitPosZ,
37    CubePos,
38    CubePosBaseName, // base name for XYZ
39    CubePosX,
40    CubePosY,
41    CubePosZ,
42    CubeDim,
43    CubeDimBaseName, // base name for XYZ
44    CubeDimX,
45    CubeDimY,
46    CubeDimZ,
47    CubeCount,
48    CubeCountBaseName, // base name for XYZ
49    CubeCountX,
50    CubeCountY,
51    CubeCountZ,
52    PlaneDim,
53    PlaneDimChecked,
54    PlanePos,
55    UnitPosPlane,
56    ClusterRank,
57    ClusterIndexX,
58    ClusterIndexY,
59    ClusterIndexZ,
60    GlobalInputArray(Id, Item<D>),
61    GlobalOutputArray(Id, Item<D>),
62    GlobalScalar {
63        id: Id,
64        elem: Elem<D>,
65        in_struct: bool,
66    },
67    ConstantArray(Id, Item<D>, u32),
68    ConstantScalar(ConstantScalarValue, Elem<D>),
69    TensorMap(Id),
70    LocalMut {
71        id: Id,
72        item: Item<D>,
73    },
74    LocalConst {
75        id: Id,
76        item: Item<D>,
77    },
78    Named {
79        name: &'static str,
80        item: Item<D>,
81    },
82    Slice {
83        id: Id,
84        item: Item<D>,
85    },
86    SharedMemory(Id, Item<D>, u32),
87    LocalArray(Id, Item<D>, u32),
88    WmmaFragment {
89        id: Id,
90        frag: Fragment<D>,
91    },
92    Pipeline {
93        id: Id,
94        item: Item<D>,
95    },
96    Barrier {
97        id: Id,
98        item: Item<D>,
99        level: BarrierLevel,
100    },
101    Tmp {
102        id: Id,
103        item: Item<D>,
104    },
105}
106
107impl<D: Dialect> Component<D> for Variable<D> {
108    fn index(&self, index: usize) -> IndexedVariable<D> {
109        self.index(index)
110    }
111
112    fn item(&self) -> Item<D> {
113        match self {
114            Variable::AbsolutePos => Item::scalar(Elem::U32, true),
115            Variable::AbsolutePosBaseName => Item {
116                elem: Elem::U32,
117                vectorization: 3,
118                native: true,
119            },
120            Variable::AbsolutePosX => Item::scalar(Elem::U32, true),
121            Variable::AbsolutePosY => Item::scalar(Elem::U32, true),
122            Variable::AbsolutePosZ => Item::scalar(Elem::U32, true),
123            Variable::CubeCount => Item::scalar(Elem::U32, true),
124            Variable::CubeCountBaseName => Item {
125                elem: Elem::U32,
126                vectorization: 3,
127                native: true,
128            },
129            Variable::CubeCountX => Item::scalar(Elem::U32, true),
130            Variable::CubeCountY => Item::scalar(Elem::U32, true),
131            Variable::CubeCountZ => Item::scalar(Elem::U32, true),
132            Variable::CubeDimBaseName => Item {
133                elem: Elem::U32,
134                vectorization: 3,
135                native: true,
136            },
137            Variable::CubeDim => Item::scalar(Elem::U32, true),
138            Variable::CubeDimX => Item::scalar(Elem::U32, true),
139            Variable::CubeDimY => Item::scalar(Elem::U32, true),
140            Variable::CubeDimZ => Item::scalar(Elem::U32, true),
141            Variable::CubePos => Item::scalar(Elem::U32, true),
142            Variable::CubePosBaseName => Item {
143                elem: Elem::U32,
144                vectorization: 3,
145                native: true,
146            },
147            Variable::CubePosX => Item::scalar(Elem::U32, true),
148            Variable::CubePosY => Item::scalar(Elem::U32, true),
149            Variable::CubePosZ => Item::scalar(Elem::U32, true),
150            Variable::UnitPos => Item::scalar(Elem::U32, true),
151            Variable::UnitPosBaseName => Item {
152                elem: Elem::U32,
153                vectorization: 3,
154                native: true,
155            },
156            Variable::UnitPosX => Item::scalar(Elem::U32, true),
157            Variable::UnitPosY => Item::scalar(Elem::U32, true),
158            Variable::UnitPosZ => Item::scalar(Elem::U32, true),
159            Variable::PlaneDim => Item::scalar(Elem::U32, true),
160            Variable::PlaneDimChecked => Item::scalar(Elem::U32, true),
161            Variable::PlanePos => Item::scalar(Elem::U32, true),
162            Variable::UnitPosPlane => Item::scalar(Elem::U32, true),
163            Variable::ClusterRank => Item::scalar(Elem::U32, true),
164            Variable::ClusterIndexX => Item::scalar(Elem::U32, true),
165            Variable::ClusterIndexY => Item::scalar(Elem::U32, true),
166            Variable::ClusterIndexZ => Item::scalar(Elem::U32, true),
167            Variable::GlobalInputArray(_, e) => *e,
168            Variable::GlobalOutputArray(_, e) => *e,
169            Variable::LocalArray(_, e, _) => *e,
170            Variable::SharedMemory(_, e, _) => *e,
171            Variable::ConstantArray(_, e, _) => *e,
172            Variable::LocalMut { item, .. } => *item,
173            Variable::LocalConst { item, .. } => *item,
174            Variable::Named { item, .. } => *item,
175            Variable::Slice { item, .. } => *item,
176            Variable::ConstantScalar(_, e) => Item::scalar(*e, false),
177            Variable::GlobalScalar { elem, .. } => Item::scalar(*elem, false),
178            Variable::WmmaFragment { frag, .. } => Item::scalar(frag.elem, false),
179            Variable::Tmp { item, .. } => *item,
180            Variable::Pipeline { id: _, item } => *item,
181            Variable::Barrier { id: _, item, .. } => *item,
182            Variable::TensorMap(_) => unreachable!(),
183        }
184    }
185
186    fn is_const(&self) -> bool {
187        matches!(self, Variable::LocalConst { .. })
188    }
189}
190
191impl<D: Dialect> Display for Variable<D> {
192    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
193        match self {
194            Variable::GlobalInputArray(id, _) => f.write_fmt(format_args!("buffer_{id}")),
195            Variable::GlobalOutputArray(id, _) => write!(f, "buffer_{id}"),
196            Variable::TensorMap(id) => write!(f, "tensor_map_{id}"),
197            Variable::LocalMut { id, .. } => f.write_fmt(format_args!("l_mut_{id}")),
198            Variable::LocalConst { id, .. } => f.write_fmt(format_args!("l_{id}")),
199            Variable::Named { name, .. } => f.write_fmt(format_args!("{name}")),
200            Variable::Slice { id, .. } => {
201                write!(f, "slice_{id}")
202            }
203            Variable::GlobalScalar {
204                id,
205                elem,
206                in_struct,
207            } => match *in_struct {
208                true => write!(f, "scalars_{elem}.x[{id}]"),
209                false => write!(f, "scalars_{elem}[{id}]"),
210            },
211            Variable::ConstantScalar(number, elem) => match number {
212                ConstantScalarValue::Int(val, kind) => match kind {
213                    gpu::IntKind::I8 => write!(f, "{elem}({})", *val as i8),
214                    gpu::IntKind::I16 => write!(f, "{elem}({})", *val as i16),
215                    gpu::IntKind::I32 => write!(f, "{elem}({})", *val as i32),
216                    gpu::IntKind::I64 => write!(f, "{elem}({})", *val),
217                },
218                ConstantScalarValue::Float(val, kind) => match kind {
219                    gpu::FloatKind::F16 => {
220                        write!(f, "{elem}({:?})", half::f16::from_f64(*val))
221                    }
222                    gpu::FloatKind::BF16 => {
223                        write!(f, "{elem}({:?})", half::bf16::from_f64(*val))
224                    }
225                    gpu::FloatKind::Flex32 => write!(f, "{elem}({:?})", *val as f32),
226                    gpu::FloatKind::TF32 => write!(f, "{elem}({:?})", *val as f32),
227                    gpu::FloatKind::F32 => write!(f, "{elem}({:?})", *val as f32),
228                    gpu::FloatKind::F64 => write!(f, "{elem}({:?})", *val),
229                },
230                ConstantScalarValue::UInt(val, kind) => match kind {
231                    gpu::UIntKind::U8 => write!(f, "{elem}({})", *val as u8),
232                    gpu::UIntKind::U16 => write!(f, "{elem}({})", *val as u16),
233                    gpu::UIntKind::U32 => write!(f, "{elem}({})", *val as u32),
234                    gpu::UIntKind::U64 => write!(f, "{elem}({})", *val),
235                },
236                ConstantScalarValue::Bool(val) => write!(f, "{}", val),
237            },
238            Variable::SharedMemory(number, _, _) => {
239                write!(f, "shared_memory_{number}")
240            }
241
242            Variable::AbsolutePos => D::compile_absolute_pos(f),
243            Variable::AbsolutePosBaseName => D::compile_absolute_pos_base_name(f),
244            Variable::AbsolutePosX => D::compile_absolute_pos_x(f),
245            Variable::AbsolutePosY => D::compile_absolute_pos_y(f),
246            Variable::AbsolutePosZ => D::compile_absolute_pos_z(f),
247            Variable::CubeCount => D::compile_cube_count(f),
248            Variable::CubeCountBaseName => D::compile_cube_count_base_name(f),
249            Variable::CubeCountX => D::compile_cube_count_x(f),
250            Variable::CubeCountY => D::compile_cube_count_y(f),
251            Variable::CubeCountZ => D::compile_cube_count_z(f),
252            Variable::CubeDim => D::compile_cube_dim(f),
253            Variable::CubeDimBaseName => D::compile_cube_dim_base_name(f),
254            Variable::CubeDimX => D::compile_cube_dim_x(f),
255            Variable::CubeDimY => D::compile_cube_dim_y(f),
256            Variable::CubeDimZ => D::compile_cube_dim_z(f),
257            Variable::CubePos => D::compile_cube_pos(f),
258            Variable::CubePosBaseName => D::compile_cube_pos_base_name(f),
259            Variable::CubePosX => D::compile_cube_pos_x(f),
260            Variable::CubePosY => D::compile_cube_pos_y(f),
261            Variable::CubePosZ => D::compile_cube_pos_z(f),
262            Variable::UnitPos => D::compile_unit_pos(f),
263            Variable::UnitPosBaseName => D::compile_unit_pos_base_name(f),
264            Variable::UnitPosX => D::compile_unit_pos_x(f),
265            Variable::UnitPosY => D::compile_unit_pos_y(f),
266            Variable::UnitPosZ => D::compile_unit_pos_z(f),
267            Variable::PlaneDim => D::compile_plane_dim(f),
268            Variable::PlaneDimChecked => D::compile_plane_dim_checked(f),
269            Variable::PlanePos => D::compile_plane_pos(f),
270            Variable::UnitPosPlane => D::compile_unit_pos_plane(f),
271            Variable::ClusterRank => D::compile_cluster_pos(f),
272            Variable::ClusterIndexX => D::compile_cluster_pos_x(f),
273            Variable::ClusterIndexY => D::compile_cluster_pos_y(f),
274            Variable::ClusterIndexZ => D::compile_cluster_pos_z(f),
275
276            Variable::ConstantArray(number, _, _) => f.write_fmt(format_args!("arrays_{number}")),
277            Variable::LocalArray(id, _, _) => {
278                write!(f, "l_arr_{}", id)
279            }
280            Variable::WmmaFragment { id: index, frag } => {
281                let name = match frag.ident {
282                    FragmentIdent::A => "a",
283                    FragmentIdent::B => "b",
284                    FragmentIdent::Accumulator => "acc",
285                    FragmentIdent::_Dialect(_) => "",
286                };
287                write!(f, "frag_{name}_{index}")
288            }
289            Variable::Tmp { id, .. } => write!(f, "_tmp_{id}"),
290            Variable::Pipeline { id, .. } => write!(f, "pipeline_{id}"),
291            Variable::Barrier { id, .. } => write!(f, "barrier_{id}"),
292        }
293    }
294}
295
296impl<D: Dialect> Variable<D> {
297    pub fn is_optimized(&self) -> bool {
298        self.item().is_optimized()
299    }
300
301    pub fn tmp(item: Item<D>) -> Self {
302        let inc = COUNTER_TMP_VAR.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
303
304        Variable::Tmp {
305            id: inc as Id,
306            item,
307        }
308    }
309
310    pub fn optimized_args<const N: usize>(args: [Self; N]) -> OptimizedArgs<N, D> {
311        let args_after = args.map(|a| a.optimized());
312
313        let item_reference_after = args_after[0].item();
314
315        let is_optimized = args_after
316            .iter()
317            .all(|var| var.elem() == item_reference_after.elem && var.is_optimized());
318
319        if is_optimized {
320            let vectorization_before = args
321                .iter()
322                .map(|var| var.item().vectorization)
323                .max()
324                .unwrap();
325            let vectorization_after = args_after
326                .iter()
327                .map(|var| var.item().vectorization)
328                .max()
329                .unwrap();
330
331            OptimizedArgs::new(args_after, Some(vectorization_before / vectorization_after))
332        } else {
333            OptimizedArgs::new(args, None)
334        }
335    }
336
337    pub fn optimized(&self) -> Self {
338        match self {
339            Variable::GlobalInputArray(id, item) => {
340                Variable::GlobalInputArray(*id, item.optimized())
341            }
342            Variable::GlobalOutputArray(id, item) => {
343                Variable::GlobalOutputArray(*id, item.optimized())
344            }
345            Variable::LocalMut { id, item } => Variable::LocalMut {
346                id: *id,
347                item: item.optimized(),
348            },
349            Variable::LocalConst { id, item } => Variable::LocalConst {
350                id: *id,
351                item: item.optimized(),
352            },
353            Variable::Slice { id, item } => Variable::Slice {
354                id: *id,
355                item: item.optimized(),
356            },
357            Variable::SharedMemory(id, item, size) => {
358                let before = item.vectorization;
359                let item = item.optimized();
360                let after = item.vectorization;
361                let scaling = (before / after) as u32;
362
363                Variable::SharedMemory(*id, item, size / scaling)
364            }
365            Variable::LocalArray(id, item, size) => {
366                let before = item.vectorization;
367                let item = item.optimized();
368                let after = item.vectorization;
369                let scaling = (before / after) as u32;
370
371                Variable::LocalArray(*id, item.optimized(), size / scaling)
372            }
373            _ => *self,
374        }
375    }
376
377    pub fn is_always_scalar(&self) -> bool {
378        match self {
379            Variable::AbsolutePos => true,
380            Variable::AbsolutePosBaseName => false,
381            Variable::AbsolutePosX => true,
382            Variable::AbsolutePosY => true,
383            Variable::AbsolutePosZ => true,
384            Variable::CubeCount => true,
385            Variable::CubeCountBaseName => false,
386            Variable::CubeCountX => true,
387            Variable::CubeCountY => true,
388            Variable::CubeCountZ => true,
389            Variable::CubeDim => true,
390            Variable::CubeDimBaseName => false,
391            Variable::CubeDimX => true,
392            Variable::CubeDimY => true,
393            Variable::CubeDimZ => true,
394            Variable::CubePos => true,
395            Variable::CubePosBaseName => true,
396            Variable::CubePosX => true,
397            Variable::CubePosY => true,
398            Variable::CubePosZ => true,
399            Variable::UnitPos => true,
400            Variable::UnitPosBaseName => true,
401            Variable::UnitPosPlane => true,
402            Variable::UnitPosX => true,
403            Variable::UnitPosY => true,
404            Variable::UnitPosZ => true,
405            Variable::PlaneDim => true,
406            Variable::PlaneDimChecked => true,
407            Variable::PlanePos => true,
408            Variable::ClusterRank => true,
409            Variable::ClusterIndexX => true,
410            Variable::ClusterIndexY => true,
411            Variable::ClusterIndexZ => true,
412
413            Variable::Barrier { .. } => false,
414            Variable::ConstantArray(_, _, _) => false,
415            Variable::ConstantScalar(_, _) => true,
416            Variable::GlobalInputArray(_, _) => false,
417            Variable::GlobalOutputArray(_, _) => false,
418            Variable::GlobalScalar { .. } => true,
419            Variable::LocalArray(_, _, _) => false,
420            Variable::LocalConst { .. } => false,
421            Variable::LocalMut { .. } => false,
422            Variable::Named { .. } => false,
423            Variable::Pipeline { .. } => false,
424            Variable::SharedMemory(_, _, _) => false,
425            Variable::Slice { .. } => false,
426            Variable::Tmp { .. } => false,
427            Variable::WmmaFragment { .. } => false,
428            Variable::TensorMap { .. } => false,
429        }
430    }
431
432    pub fn index(&self, index: usize) -> IndexedVariable<D> {
433        IndexedVariable {
434            var: *self,
435            index,
436            optimized: self.is_optimized(),
437        }
438    }
439
440    pub fn const_qualifier(&self) -> &str {
441        if self.is_const() { " const" } else { "" }
442    }
443
444    pub fn id(&self) -> Option<Id> {
445        match self {
446            Variable::GlobalInputArray(id, ..) => Some(*id),
447            Variable::GlobalOutputArray(id, ..) => Some(*id),
448            Variable::GlobalScalar { id, .. } => Some(*id),
449            Variable::ConstantArray(id, ..) => Some(*id),
450            Variable::LocalMut { id, .. } => Some(*id),
451            Variable::LocalConst { id, .. } => Some(*id),
452            Variable::Slice { id, .. } => Some(*id),
453            Variable::SharedMemory(id, ..) => Some(*id),
454            Variable::LocalArray(id, ..) => Some(*id),
455            Variable::WmmaFragment { id, .. } => Some(*id),
456            Variable::Pipeline { id, .. } => Some(*id),
457            Variable::Barrier { id, .. } => Some(*id),
458            Variable::Tmp { id, .. } => Some(*id),
459            _ => None,
460        }
461    }
462
463    /// Format variable for a pointer argument. Slices and buffers are already pointers, so we
464    /// just leave them as is to avoid accidental double pointers
465    pub fn fmt_ptr(&self) -> String {
466        match self {
467            Variable::Slice { .. }
468            | Variable::GlobalInputArray(_, _)
469            | Variable::GlobalOutputArray(_, _) => format!("{self}"),
470            _ => format!("&{self}"),
471        }
472    }
473}
474
475impl<D: Dialect> FmtLeft for Variable<D> {
476    fn fmt_left(&self) -> String {
477        match self {
478            Self::LocalConst { item, .. } => format!("const {item} {self}"),
479            Variable::Tmp { item, .. } => format!("{item} {self}"),
480            var => format!("{var}"),
481        }
482    }
483}
484
485#[derive(Debug, Clone)]
486pub struct IndexedVariable<D: Dialect> {
487    var: Variable<D>,
488    optimized: bool,
489    index: usize,
490}
491
492impl<D: Dialect> Component<D> for IndexedVariable<D> {
493    fn item(&self) -> Item<D> {
494        self.var.item()
495    }
496
497    fn index(&self, index: usize) -> IndexedVariable<D> {
498        self.var.index(index)
499    }
500
501    fn is_const(&self) -> bool {
502        matches!(self.var, Variable::LocalConst { .. })
503    }
504}
505
506impl<D: Dialect> Display for IndexedVariable<D> {
507    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
508        let var = &self.var;
509        let ref_ = matches!(var, Variable::LocalConst { .. })
510            .then_some("const&")
511            .unwrap_or("&");
512
513        if self.var.item().vectorization > 1 {
514            if self.optimized {
515                let item = self.var.item();
516                let addr_space = D::address_space_for_variable(&self.var);
517                write!(
518                    f,
519                    "(reinterpret_cast<{addr_space}{item} {ref_}>({var})).i_{}",
520                    self.index
521                )
522            } else {
523                write!(f, "{var}.i_{}", self.index)
524            }
525        } else if self.optimized {
526            let item = self.var.item();
527            let addr_space = D::address_space_for_variable(&self.var);
528            write!(f, "reinterpret_cast<{addr_space}{item} {ref_}>({var})")
529        } else {
530            write!(f, "{var}")
531        }
532    }
533}
534
535impl<D: Dialect> FmtLeft for IndexedVariable<D> {
536    fn fmt_left(&self) -> String {
537        match self.var {
538            Variable::LocalConst { item, .. } => format!("const {item} {self}"),
539            Variable::Tmp { item, .. } => format!("{item} {self}"),
540            _ => format!("{self}"),
541        }
542    }
543}
544
545impl FmtLeft for &String {
546    fn fmt_left(&self) -> String {
547        self.to_string()
548    }
549}