1use cubecl_core::ir::{self as gpu, BarrierLevel, ConstantScalarValue, Id};
2use std::fmt::Display;
3
4use super::{COUNTER_TMP_VAR, Dialect, Elem, Fragment, FragmentIdent, Item};
5
6pub trait Component<D: Dialect>: Display + FmtLeft {
7 fn item(&self) -> Item<D>;
8 fn is_const(&self) -> bool;
9 fn index(&self, index: usize) -> IndexedVariable<D>;
10 fn elem(&self) -> Elem<D> {
11 *self.item().elem()
12 }
13}
14
/// Alternative textual form of a displayable value.
///
/// For the variable types in this file the returned text is the plain
/// [`Display`] output, except that `LocalConst` and `Tmp` variables get their
/// item type prepended (plus a leading `const` for `LocalConst`).
pub trait FmtLeft: Display {
    /// Returns the text to emit for `self` in its "left" form.
    fn fmt_left(&self) -> String;
}
18
19#[derive(new)]
20pub struct OptimizedArgs<const N: usize, D: Dialect> {
21 pub args: [Variable<D>; N],
22 pub optimization_factor: Option<usize>,
23}
24
25#[derive(Debug, Clone, Copy, PartialEq)]
26pub enum Variable<D: Dialect> {
27 AbsolutePos,
28 AbsolutePosBaseName, AbsolutePosX,
30 AbsolutePosY,
31 AbsolutePosZ,
32 UnitPos,
33 UnitPosBaseName, UnitPosX,
35 UnitPosY,
36 UnitPosZ,
37 CubePos,
38 CubePosBaseName, CubePosX,
40 CubePosY,
41 CubePosZ,
42 CubeDim,
43 CubeDimBaseName, CubeDimX,
45 CubeDimY,
46 CubeDimZ,
47 CubeCount,
48 CubeCountBaseName, CubeCountX,
50 CubeCountY,
51 CubeCountZ,
52 PlaneDim,
53 PlaneDimChecked,
54 PlanePos,
55 UnitPosPlane,
56 ClusterRank,
57 ClusterIndexX,
58 ClusterIndexY,
59 ClusterIndexZ,
60 GlobalInputArray(Id, Item<D>),
61 GlobalOutputArray(Id, Item<D>),
62 GlobalScalar {
63 id: Id,
64 elem: Elem<D>,
65 in_struct: bool,
66 },
67 ConstantArray(Id, Item<D>, u32),
68 ConstantScalar(ConstantScalarValue, Elem<D>),
69 TensorMap(Id),
70 LocalMut {
71 id: Id,
72 item: Item<D>,
73 },
74 LocalConst {
75 id: Id,
76 item: Item<D>,
77 },
78 Named {
79 name: &'static str,
80 item: Item<D>,
81 },
82 Slice {
83 id: Id,
84 item: Item<D>,
85 },
86 SharedMemory(Id, Item<D>, u32),
87 LocalArray(Id, Item<D>, u32),
88 WmmaFragment {
89 id: Id,
90 frag: Fragment<D>,
91 },
92 Pipeline {
93 id: Id,
94 item: Item<D>,
95 },
96 Barrier {
97 id: Id,
98 item: Item<D>,
99 level: BarrierLevel,
100 },
101 Tmp {
102 id: Id,
103 item: Item<D>,
104 },
105}
106
107impl<D: Dialect> Component<D> for Variable<D> {
108 fn index(&self, index: usize) -> IndexedVariable<D> {
109 self.index(index)
110 }
111
112 fn item(&self) -> Item<D> {
113 match self {
114 Variable::AbsolutePos => Item::scalar(Elem::U32, true),
115 Variable::AbsolutePosBaseName => Item {
116 elem: Elem::U32,
117 vectorization: 3,
118 native: true,
119 },
120 Variable::AbsolutePosX => Item::scalar(Elem::U32, true),
121 Variable::AbsolutePosY => Item::scalar(Elem::U32, true),
122 Variable::AbsolutePosZ => Item::scalar(Elem::U32, true),
123 Variable::CubeCount => Item::scalar(Elem::U32, true),
124 Variable::CubeCountBaseName => Item {
125 elem: Elem::U32,
126 vectorization: 3,
127 native: true,
128 },
129 Variable::CubeCountX => Item::scalar(Elem::U32, true),
130 Variable::CubeCountY => Item::scalar(Elem::U32, true),
131 Variable::CubeCountZ => Item::scalar(Elem::U32, true),
132 Variable::CubeDimBaseName => Item {
133 elem: Elem::U32,
134 vectorization: 3,
135 native: true,
136 },
137 Variable::CubeDim => Item::scalar(Elem::U32, true),
138 Variable::CubeDimX => Item::scalar(Elem::U32, true),
139 Variable::CubeDimY => Item::scalar(Elem::U32, true),
140 Variable::CubeDimZ => Item::scalar(Elem::U32, true),
141 Variable::CubePos => Item::scalar(Elem::U32, true),
142 Variable::CubePosBaseName => Item {
143 elem: Elem::U32,
144 vectorization: 3,
145 native: true,
146 },
147 Variable::CubePosX => Item::scalar(Elem::U32, true),
148 Variable::CubePosY => Item::scalar(Elem::U32, true),
149 Variable::CubePosZ => Item::scalar(Elem::U32, true),
150 Variable::UnitPos => Item::scalar(Elem::U32, true),
151 Variable::UnitPosBaseName => Item {
152 elem: Elem::U32,
153 vectorization: 3,
154 native: true,
155 },
156 Variable::UnitPosX => Item::scalar(Elem::U32, true),
157 Variable::UnitPosY => Item::scalar(Elem::U32, true),
158 Variable::UnitPosZ => Item::scalar(Elem::U32, true),
159 Variable::PlaneDim => Item::scalar(Elem::U32, true),
160 Variable::PlaneDimChecked => Item::scalar(Elem::U32, true),
161 Variable::PlanePos => Item::scalar(Elem::U32, true),
162 Variable::UnitPosPlane => Item::scalar(Elem::U32, true),
163 Variable::ClusterRank => Item::scalar(Elem::U32, true),
164 Variable::ClusterIndexX => Item::scalar(Elem::U32, true),
165 Variable::ClusterIndexY => Item::scalar(Elem::U32, true),
166 Variable::ClusterIndexZ => Item::scalar(Elem::U32, true),
167 Variable::GlobalInputArray(_, e) => *e,
168 Variable::GlobalOutputArray(_, e) => *e,
169 Variable::LocalArray(_, e, _) => *e,
170 Variable::SharedMemory(_, e, _) => *e,
171 Variable::ConstantArray(_, e, _) => *e,
172 Variable::LocalMut { item, .. } => *item,
173 Variable::LocalConst { item, .. } => *item,
174 Variable::Named { item, .. } => *item,
175 Variable::Slice { item, .. } => *item,
176 Variable::ConstantScalar(_, e) => Item::scalar(*e, false),
177 Variable::GlobalScalar { elem, .. } => Item::scalar(*elem, false),
178 Variable::WmmaFragment { frag, .. } => Item::scalar(frag.elem, false),
179 Variable::Tmp { item, .. } => *item,
180 Variable::Pipeline { id: _, item } => *item,
181 Variable::Barrier { id: _, item, .. } => *item,
182 Variable::TensorMap(_) => unreachable!(),
183 }
184 }
185
186 fn is_const(&self) -> bool {
187 matches!(self, Variable::LocalConst { .. })
188 }
189}
190
191impl<D: Dialect> Display for Variable<D> {
192 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
193 match self {
194 Variable::GlobalInputArray(id, _) => f.write_fmt(format_args!("buffer_{id}")),
195 Variable::GlobalOutputArray(id, _) => write!(f, "buffer_{id}"),
196 Variable::TensorMap(id) => write!(f, "tensor_map_{id}"),
197 Variable::LocalMut { id, .. } => f.write_fmt(format_args!("l_mut_{id}")),
198 Variable::LocalConst { id, .. } => f.write_fmt(format_args!("l_{id}")),
199 Variable::Named { name, .. } => f.write_fmt(format_args!("{name}")),
200 Variable::Slice { id, .. } => {
201 write!(f, "slice_{id}")
202 }
203 Variable::GlobalScalar {
204 id,
205 elem,
206 in_struct,
207 } => match *in_struct {
208 true => write!(f, "scalars_{elem}.x[{id}]"),
209 false => write!(f, "scalars_{elem}[{id}]"),
210 },
211 Variable::ConstantScalar(number, elem) => match number {
212 ConstantScalarValue::Int(val, kind) => match kind {
213 gpu::IntKind::I8 => write!(f, "{elem}({})", *val as i8),
214 gpu::IntKind::I16 => write!(f, "{elem}({})", *val as i16),
215 gpu::IntKind::I32 => write!(f, "{elem}({})", *val as i32),
216 gpu::IntKind::I64 => write!(f, "{elem}({})", *val),
217 },
218 ConstantScalarValue::Float(val, kind) => match kind {
219 gpu::FloatKind::F16 => {
220 write!(f, "{elem}({:?})", half::f16::from_f64(*val))
221 }
222 gpu::FloatKind::BF16 => {
223 write!(f, "{elem}({:?})", half::bf16::from_f64(*val))
224 }
225 gpu::FloatKind::Flex32 => write!(f, "{elem}({:?})", *val as f32),
226 gpu::FloatKind::TF32 => write!(f, "{elem}({:?})", *val as f32),
227 gpu::FloatKind::F32 => write!(f, "{elem}({:?})", *val as f32),
228 gpu::FloatKind::F64 => write!(f, "{elem}({:?})", *val),
229 },
230 ConstantScalarValue::UInt(val, kind) => match kind {
231 gpu::UIntKind::U8 => write!(f, "{elem}({})", *val as u8),
232 gpu::UIntKind::U16 => write!(f, "{elem}({})", *val as u16),
233 gpu::UIntKind::U32 => write!(f, "{elem}({})", *val as u32),
234 gpu::UIntKind::U64 => write!(f, "{elem}({})", *val),
235 },
236 ConstantScalarValue::Bool(val) => write!(f, "{}", val),
237 },
238 Variable::SharedMemory(number, _, _) => {
239 write!(f, "shared_memory_{number}")
240 }
241
242 Variable::AbsolutePos => D::compile_absolute_pos(f),
243 Variable::AbsolutePosBaseName => D::compile_absolute_pos_base_name(f),
244 Variable::AbsolutePosX => D::compile_absolute_pos_x(f),
245 Variable::AbsolutePosY => D::compile_absolute_pos_y(f),
246 Variable::AbsolutePosZ => D::compile_absolute_pos_z(f),
247 Variable::CubeCount => D::compile_cube_count(f),
248 Variable::CubeCountBaseName => D::compile_cube_count_base_name(f),
249 Variable::CubeCountX => D::compile_cube_count_x(f),
250 Variable::CubeCountY => D::compile_cube_count_y(f),
251 Variable::CubeCountZ => D::compile_cube_count_z(f),
252 Variable::CubeDim => D::compile_cube_dim(f),
253 Variable::CubeDimBaseName => D::compile_cube_dim_base_name(f),
254 Variable::CubeDimX => D::compile_cube_dim_x(f),
255 Variable::CubeDimY => D::compile_cube_dim_y(f),
256 Variable::CubeDimZ => D::compile_cube_dim_z(f),
257 Variable::CubePos => D::compile_cube_pos(f),
258 Variable::CubePosBaseName => D::compile_cube_pos_base_name(f),
259 Variable::CubePosX => D::compile_cube_pos_x(f),
260 Variable::CubePosY => D::compile_cube_pos_y(f),
261 Variable::CubePosZ => D::compile_cube_pos_z(f),
262 Variable::UnitPos => D::compile_unit_pos(f),
263 Variable::UnitPosBaseName => D::compile_unit_pos_base_name(f),
264 Variable::UnitPosX => D::compile_unit_pos_x(f),
265 Variable::UnitPosY => D::compile_unit_pos_y(f),
266 Variable::UnitPosZ => D::compile_unit_pos_z(f),
267 Variable::PlaneDim => D::compile_plane_dim(f),
268 Variable::PlaneDimChecked => D::compile_plane_dim_checked(f),
269 Variable::PlanePos => D::compile_plane_pos(f),
270 Variable::UnitPosPlane => D::compile_unit_pos_plane(f),
271 Variable::ClusterRank => D::compile_cluster_pos(f),
272 Variable::ClusterIndexX => D::compile_cluster_pos_x(f),
273 Variable::ClusterIndexY => D::compile_cluster_pos_y(f),
274 Variable::ClusterIndexZ => D::compile_cluster_pos_z(f),
275
276 Variable::ConstantArray(number, _, _) => f.write_fmt(format_args!("arrays_{number}")),
277 Variable::LocalArray(id, _, _) => {
278 write!(f, "l_arr_{}", id)
279 }
280 Variable::WmmaFragment { id: index, frag } => {
281 let name = match frag.ident {
282 FragmentIdent::A => "a",
283 FragmentIdent::B => "b",
284 FragmentIdent::Accumulator => "acc",
285 FragmentIdent::_Dialect(_) => "",
286 };
287 write!(f, "frag_{name}_{index}")
288 }
289 Variable::Tmp { id, .. } => write!(f, "_tmp_{id}"),
290 Variable::Pipeline { id, .. } => write!(f, "pipeline_{id}"),
291 Variable::Barrier { id, .. } => write!(f, "barrier_{id}"),
292 }
293 }
294}
295
296impl<D: Dialect> Variable<D> {
297 pub fn is_optimized(&self) -> bool {
298 self.item().is_optimized()
299 }
300
301 pub fn tmp(item: Item<D>) -> Self {
302 let inc = COUNTER_TMP_VAR.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
303
304 Variable::Tmp {
305 id: inc as Id,
306 item,
307 }
308 }
309
310 pub fn optimized_args<const N: usize>(args: [Self; N]) -> OptimizedArgs<N, D> {
311 let args_after = args.map(|a| a.optimized());
312
313 let item_reference_after = args_after[0].item();
314
315 let is_optimized = args_after
316 .iter()
317 .all(|var| var.elem() == item_reference_after.elem && var.is_optimized());
318
319 if is_optimized {
320 let vectorization_before = args
321 .iter()
322 .map(|var| var.item().vectorization)
323 .max()
324 .unwrap();
325 let vectorization_after = args_after
326 .iter()
327 .map(|var| var.item().vectorization)
328 .max()
329 .unwrap();
330
331 OptimizedArgs::new(args_after, Some(vectorization_before / vectorization_after))
332 } else {
333 OptimizedArgs::new(args, None)
334 }
335 }
336
337 pub fn optimized(&self) -> Self {
338 match self {
339 Variable::GlobalInputArray(id, item) => {
340 Variable::GlobalInputArray(*id, item.optimized())
341 }
342 Variable::GlobalOutputArray(id, item) => {
343 Variable::GlobalOutputArray(*id, item.optimized())
344 }
345 Variable::LocalMut { id, item } => Variable::LocalMut {
346 id: *id,
347 item: item.optimized(),
348 },
349 Variable::LocalConst { id, item } => Variable::LocalConst {
350 id: *id,
351 item: item.optimized(),
352 },
353 Variable::Slice { id, item } => Variable::Slice {
354 id: *id,
355 item: item.optimized(),
356 },
357 Variable::SharedMemory(id, item, size) => {
358 let before = item.vectorization;
359 let item = item.optimized();
360 let after = item.vectorization;
361 let scaling = (before / after) as u32;
362
363 Variable::SharedMemory(*id, item, size / scaling)
364 }
365 Variable::LocalArray(id, item, size) => {
366 let before = item.vectorization;
367 let item = item.optimized();
368 let after = item.vectorization;
369 let scaling = (before / after) as u32;
370
371 Variable::LocalArray(*id, item.optimized(), size / scaling)
372 }
373 _ => *self,
374 }
375 }
376
377 pub fn is_always_scalar(&self) -> bool {
378 match self {
379 Variable::AbsolutePos => true,
380 Variable::AbsolutePosBaseName => false,
381 Variable::AbsolutePosX => true,
382 Variable::AbsolutePosY => true,
383 Variable::AbsolutePosZ => true,
384 Variable::CubeCount => true,
385 Variable::CubeCountBaseName => false,
386 Variable::CubeCountX => true,
387 Variable::CubeCountY => true,
388 Variable::CubeCountZ => true,
389 Variable::CubeDim => true,
390 Variable::CubeDimBaseName => false,
391 Variable::CubeDimX => true,
392 Variable::CubeDimY => true,
393 Variable::CubeDimZ => true,
394 Variable::CubePos => true,
395 Variable::CubePosBaseName => true,
396 Variable::CubePosX => true,
397 Variable::CubePosY => true,
398 Variable::CubePosZ => true,
399 Variable::UnitPos => true,
400 Variable::UnitPosBaseName => true,
401 Variable::UnitPosPlane => true,
402 Variable::UnitPosX => true,
403 Variable::UnitPosY => true,
404 Variable::UnitPosZ => true,
405 Variable::PlaneDim => true,
406 Variable::PlaneDimChecked => true,
407 Variable::PlanePos => true,
408 Variable::ClusterRank => true,
409 Variable::ClusterIndexX => true,
410 Variable::ClusterIndexY => true,
411 Variable::ClusterIndexZ => true,
412
413 Variable::Barrier { .. } => false,
414 Variable::ConstantArray(_, _, _) => false,
415 Variable::ConstantScalar(_, _) => true,
416 Variable::GlobalInputArray(_, _) => false,
417 Variable::GlobalOutputArray(_, _) => false,
418 Variable::GlobalScalar { .. } => true,
419 Variable::LocalArray(_, _, _) => false,
420 Variable::LocalConst { .. } => false,
421 Variable::LocalMut { .. } => false,
422 Variable::Named { .. } => false,
423 Variable::Pipeline { .. } => false,
424 Variable::SharedMemory(_, _, _) => false,
425 Variable::Slice { .. } => false,
426 Variable::Tmp { .. } => false,
427 Variable::WmmaFragment { .. } => false,
428 Variable::TensorMap { .. } => false,
429 }
430 }
431
432 pub fn index(&self, index: usize) -> IndexedVariable<D> {
433 IndexedVariable {
434 var: *self,
435 index,
436 optimized: self.is_optimized(),
437 }
438 }
439
440 pub fn const_qualifier(&self) -> &str {
441 if self.is_const() { " const" } else { "" }
442 }
443
444 pub fn id(&self) -> Option<Id> {
445 match self {
446 Variable::GlobalInputArray(id, ..) => Some(*id),
447 Variable::GlobalOutputArray(id, ..) => Some(*id),
448 Variable::GlobalScalar { id, .. } => Some(*id),
449 Variable::ConstantArray(id, ..) => Some(*id),
450 Variable::LocalMut { id, .. } => Some(*id),
451 Variable::LocalConst { id, .. } => Some(*id),
452 Variable::Slice { id, .. } => Some(*id),
453 Variable::SharedMemory(id, ..) => Some(*id),
454 Variable::LocalArray(id, ..) => Some(*id),
455 Variable::WmmaFragment { id, .. } => Some(*id),
456 Variable::Pipeline { id, .. } => Some(*id),
457 Variable::Barrier { id, .. } => Some(*id),
458 Variable::Tmp { id, .. } => Some(*id),
459 _ => None,
460 }
461 }
462
463 pub fn fmt_ptr(&self) -> String {
466 match self {
467 Variable::Slice { .. }
468 | Variable::GlobalInputArray(_, _)
469 | Variable::GlobalOutputArray(_, _) => format!("{self}"),
470 _ => format!("&{self}"),
471 }
472 }
473}
474
475impl<D: Dialect> FmtLeft for Variable<D> {
476 fn fmt_left(&self) -> String {
477 match self {
478 Self::LocalConst { item, .. } => format!("const {item} {self}"),
479 Variable::Tmp { item, .. } => format!("{item} {self}"),
480 var => format!("{var}"),
481 }
482 }
483}
484
485#[derive(Debug, Clone)]
486pub struct IndexedVariable<D: Dialect> {
487 var: Variable<D>,
488 optimized: bool,
489 index: usize,
490}
491
492impl<D: Dialect> Component<D> for IndexedVariable<D> {
493 fn item(&self) -> Item<D> {
494 self.var.item()
495 }
496
497 fn index(&self, index: usize) -> IndexedVariable<D> {
498 self.var.index(index)
499 }
500
501 fn is_const(&self) -> bool {
502 matches!(self.var, Variable::LocalConst { .. })
503 }
504}
505
506impl<D: Dialect> Display for IndexedVariable<D> {
507 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
508 let var = &self.var;
509 let ref_ = matches!(var, Variable::LocalConst { .. })
510 .then_some("const&")
511 .unwrap_or("&");
512
513 if self.var.item().vectorization > 1 {
514 if self.optimized {
515 let item = self.var.item();
516 let addr_space = D::address_space_for_variable(&self.var);
517 write!(
518 f,
519 "(reinterpret_cast<{addr_space}{item} {ref_}>({var})).i_{}",
520 self.index
521 )
522 } else {
523 write!(f, "{var}.i_{}", self.index)
524 }
525 } else if self.optimized {
526 let item = self.var.item();
527 let addr_space = D::address_space_for_variable(&self.var);
528 write!(f, "reinterpret_cast<{addr_space}{item} {ref_}>({var})")
529 } else {
530 write!(f, "{var}")
531 }
532 }
533}
534
535impl<D: Dialect> FmtLeft for IndexedVariable<D> {
536 fn fmt_left(&self) -> String {
537 match self.var {
538 Variable::LocalConst { item, .. } => format!("const {item} {self}"),
539 Variable::Tmp { item, .. } => format!("{item} {self}"),
540 _ => format!("{self}"),
541 }
542 }
543}
544
545impl FmtLeft for &String {
546 fn fmt_left(&self) -> String {
547 self.to_string()
548 }
549}