1use cubecl_core::ir::{BarrierLevel, ConstantValue, Id};
2use std::fmt::{Display, Formatter};
3
4use super::{COUNTER_TMP_VAR, Dialect, Elem, Fragment, FragmentIdent, Item};
5
6pub trait Component<D: Dialect>: Display + FmtLeft {
7 fn item(&self) -> Item<D>;
8 fn is_const(&self) -> bool;
9 fn index(&self, index: usize) -> IndexedVariable<D>;
10 fn elem(&self) -> Elem<D> {
11 *self.item().elem()
12 }
13}
14
/// Formatting used when a value appears on the left-hand side of an
/// assignment.
///
/// Depending on the implementation, the result may include a type (and
/// therefore act as a declaration) rather than just the bare name.
pub trait FmtLeft: Display {
    /// Renders the assignment-target form of this value.
    fn fmt_left(&self) -> String;
}
18
19#[derive(new)]
20pub struct OptimizedArgs<const N: usize, D: Dialect> {
21 pub args: [Variable<D>; N],
22 pub optimization_factor: Option<usize>,
23}
24
/// An operand that can be emitted into generated kernel source.
///
/// Builtin variants (positions, dimensions, counts, plane and cluster indices)
/// are rendered by the `Display` impl through the dialect's `D::compile_*`
/// hooks. All other variants are rendered from their id, e.g. `buffer_{id}`,
/// `l_mut_{id}`, `shared_memory_{id}` or `_tmp_{id}`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Variable<D: Dialect> {
    // Builtins. For each of AbsolutePos / UnitPos / CubePos / CubeDim /
    // CubeCount: the `*BaseName` variant has a 3-wide `u32` item, and the X/Y/Z
    // variants are `u32` scalars (see `Component::item`). `AbsolutePos`,
    // `CubePos` and `CubeCount` carry the element type they are read as;
    // `UnitPos` and `CubeDim` are plain `u32` scalars.
    AbsolutePos(Elem<D>),
    AbsolutePosBaseName,
    AbsolutePosX,
    AbsolutePosY,
    AbsolutePosZ,
    UnitPos,
    UnitPosBaseName,
    UnitPosX,
    UnitPosY,
    UnitPosZ,
    CubePos(Elem<D>),
    CubePosBaseName,
    CubePosX,
    CubePosY,
    CubePosZ,
    CubeDim,
    CubeDimBaseName,
    CubeDimX,
    CubeDimY,
    CubeDimZ,
    CubeCount(Elem<D>),
    CubeCountBaseName,
    CubeCountX,
    CubeCountY,
    CubeCountZ,
    // Plane-level builtins (`u32` scalars rendered through the dialect hooks).
    PlaneDim,
    PlaneDimChecked,
    PlanePos,
    UnitPosPlane,
    // Cluster builtins. Note that `ClusterRank` is rendered through
    // `D::compile_cluster_pos` and the index variants through
    // `D::compile_cluster_pos_{x,y,z}`.
    ClusterRank,
    ClusterIndexX,
    ClusterIndexY,
    ClusterIndexZ,
    /// Global input buffer, rendered as `buffer_{id}`; reported const by `is_const`.
    GlobalInputArray(Id, Item<D>),
    /// Global output buffer, also rendered as `buffer_{id}`.
    GlobalOutputArray(Id, Item<D>),
    /// Scalar kernel argument, rendered as `scalars_{elem}[{id}]`, or as
    /// `scalars_{elem}.x[{id}]` when `in_struct` is set.
    GlobalScalar {
        id: Id,
        elem: Elem<D>,
        in_struct: bool,
    },
    /// Constant array rendered as `arrays_{id}`; the `usize` is presumably its
    /// length (it is not used for rendering) — confirm against the declaration site.
    ConstantArray(Id, Item<D>, usize),
    /// Literal value. Rendered as `{item}({value})`, or as a brace initializer
    /// repeating the value once per lane when vectorized.
    Constant(ConstantValue, Item<D>),
    /// Tensor map handle, rendered as `tensor_map_{id}`; it has no item, and
    /// `Component::item` hits `unreachable!()` for it.
    TensorMap(Id),
    /// Mutable local, rendered as `l_mut_{id}`.
    LocalMut {
        id: Id,
        item: Item<D>,
    },
    /// Immutable local, rendered as `l_{id}`; reported const by `is_const`.
    LocalConst {
        id: Id,
        item: Item<D>,
    },
    /// Variable emitted verbatim under a fixed name.
    Named {
        name: &'static str,
        item: Item<D>,
    },
    /// Slice, rendered as `slice_{id}`.
    Slice {
        id: Id,
        item: Item<D>,
    },
    /// Shared array rendered as `shared_memory_{id}`. The `usize` is its
    /// element count; `optimized` divides it by the vectorization reduction.
    SharedArray(Id, Item<D>, usize),
    /// Non-array shared value, also rendered as `shared_memory_{id}`.
    Shared(Id, Item<D>),
    /// Local array rendered as `l_arr_{id}`. The `usize` is its element count;
    /// `optimized` divides it by the vectorization reduction.
    LocalArray(Id, Item<D>, usize),
    /// Matrix fragment, rendered as `frag_{a|b|acc}_{id}`.
    WmmaFragment {
        id: Id,
        frag: Fragment<D>,
    },
    /// Pipeline object, rendered as `pipeline_{id}`.
    Pipeline {
        id: Id,
    },
    /// Barrier object, rendered as `barrier_{id}`.
    Barrier {
        id: Id,
        level: BarrierLevel,
    },
    /// Token associated with a barrier, rendered as `barrier_{id}_token`.
    BarrierToken {
        id: Id,
        level: BarrierLevel,
    },
    /// Compiler temporary, rendered as `_tmp_{id}`.
    ///
    /// - `is_declared`: already declared, so `fmt_left` emits only the name.
    /// - `is_ptr`: `fmt_left` declares it as a pointer (`{item} *name`).
    /// - `is_const`: reported by `is_const`; adds `const` to the pointer declaration.
    Tmp {
        id: Id,
        item: Item<D>,
        is_declared: bool,
        is_ptr: bool,
        is_const: bool,
    },
}
112
113impl<D: Dialect> Component<D> for Variable<D> {
114 fn index(&self, index: usize) -> IndexedVariable<D> {
115 self.index(index)
116 }
117
118 fn item(&self) -> Item<D> {
119 match self {
120 Variable::AbsolutePos(elem) => Item::scalar(*elem, true),
121 Variable::AbsolutePosBaseName => Item {
122 elem: Elem::U32,
123 vectorization: 3,
124 native: true,
125 },
126 Variable::AbsolutePosX => Item::scalar(Elem::U32, true),
127 Variable::AbsolutePosY => Item::scalar(Elem::U32, true),
128 Variable::AbsolutePosZ => Item::scalar(Elem::U32, true),
129 Variable::CubeCount(elem) => Item::scalar(*elem, true),
130 Variable::CubeCountBaseName => Item {
131 elem: Elem::U32,
132 vectorization: 3,
133 native: true,
134 },
135 Variable::CubeCountX => Item::scalar(Elem::U32, true),
136 Variable::CubeCountY => Item::scalar(Elem::U32, true),
137 Variable::CubeCountZ => Item::scalar(Elem::U32, true),
138 Variable::CubeDimBaseName => Item {
139 elem: Elem::U32,
140 vectorization: 3,
141 native: true,
142 },
143 Variable::CubeDim => Item::scalar(Elem::U32, true),
144 Variable::CubeDimX => Item::scalar(Elem::U32, true),
145 Variable::CubeDimY => Item::scalar(Elem::U32, true),
146 Variable::CubeDimZ => Item::scalar(Elem::U32, true),
147 Variable::CubePos(elem) => Item::scalar(*elem, true),
148 Variable::CubePosBaseName => Item {
149 elem: Elem::U32,
150 vectorization: 3,
151 native: true,
152 },
153 Variable::CubePosX => Item::scalar(Elem::U32, true),
154 Variable::CubePosY => Item::scalar(Elem::U32, true),
155 Variable::CubePosZ => Item::scalar(Elem::U32, true),
156 Variable::UnitPos => Item::scalar(Elem::U32, true),
157 Variable::UnitPosBaseName => Item {
158 elem: Elem::U32,
159 vectorization: 3,
160 native: true,
161 },
162 Variable::UnitPosX => Item::scalar(Elem::U32, true),
163 Variable::UnitPosY => Item::scalar(Elem::U32, true),
164 Variable::UnitPosZ => Item::scalar(Elem::U32, true),
165 Variable::PlaneDim => Item::scalar(Elem::U32, true),
166 Variable::PlaneDimChecked => Item::scalar(Elem::U32, true),
167 Variable::PlanePos => Item::scalar(Elem::U32, true),
168 Variable::UnitPosPlane => Item::scalar(Elem::U32, true),
169 Variable::ClusterRank => Item::scalar(Elem::U32, true),
170 Variable::ClusterIndexX => Item::scalar(Elem::U32, true),
171 Variable::ClusterIndexY => Item::scalar(Elem::U32, true),
172 Variable::ClusterIndexZ => Item::scalar(Elem::U32, true),
173 Variable::GlobalInputArray(_, e) => *e,
174 Variable::GlobalOutputArray(_, e) => *e,
175 Variable::LocalArray(_, e, _) => *e,
176 Variable::SharedArray(_, e, _) => *e,
177 Variable::Shared(_, e) => *e,
178 Variable::ConstantArray(_, e, _) => *e,
179 Variable::LocalMut { item, .. } => *item,
180 Variable::LocalConst { item, .. } => *item,
181 Variable::Named { item, .. } => *item,
182 Variable::Slice { item, .. } => *item,
183 Variable::Constant(_, e) => *e,
184 Variable::GlobalScalar { elem, .. } => Item::scalar(*elem, false),
185 Variable::WmmaFragment { frag, .. } => Item::scalar(frag.elem, false),
186 Variable::Tmp { item, .. } => *item,
187 Variable::Pipeline { .. }
188 | Variable::Barrier { .. }
189 | Variable::BarrierToken { .. } => Item::new(Elem::Bool, 1, false),
190 Variable::TensorMap(_) => unreachable!(),
191 }
192 }
193
194 fn is_const(&self) -> bool {
195 if let Variable::Tmp { is_const, .. } = self {
196 return *is_const;
197 }
198
199 matches!(
200 self,
201 Variable::LocalConst { .. } | Variable::GlobalInputArray { .. }
202 )
203 }
204}
205
206impl<D: Dialect> Display for Variable<D> {
207 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
208 match self {
209 Variable::GlobalInputArray(id, _) => f.write_fmt(format_args!("buffer_{id}")),
210 Variable::GlobalOutputArray(id, _) => write!(f, "buffer_{id}"),
211 Variable::TensorMap(id) => write!(f, "tensor_map_{id}"),
212 Variable::LocalMut { id, .. } => f.write_fmt(format_args!("l_mut_{id}")),
213 Variable::LocalConst { id, .. } => f.write_fmt(format_args!("l_{id}")),
214 Variable::Named { name, .. } => f.write_fmt(format_args!("{name}")),
215 Variable::Slice { id, .. } => {
216 write!(f, "slice_{id}")
217 }
218 Variable::GlobalScalar {
219 id,
220 elem,
221 in_struct,
222 } => match *in_struct {
223 true => write!(f, "scalars_{elem}.x[{id}]"),
224 false => write!(f, "scalars_{elem}[{id}]"),
225 },
226 Variable::Constant(number, item) if item.vectorization <= 1 => {
227 write!(f, "{item}({number})")
228 }
229 Variable::Constant(number, item) => {
230 let values = (0..item.vectorization)
231 .map(|_| format!("{}({number})", item.elem()))
232 .collect::<Vec<_>>();
233 write!(f, "{item} {{ {} }}", values.join(","))
234 }
235 Variable::SharedArray(number, _, _) | Variable::Shared(number, _) => {
236 write!(f, "shared_memory_{number}")
237 }
238
239 Variable::AbsolutePos(_) => D::compile_absolute_pos(f),
240 Variable::AbsolutePosBaseName => D::compile_absolute_pos_base_name(f),
241 Variable::AbsolutePosX => D::compile_absolute_pos_x(f),
242 Variable::AbsolutePosY => D::compile_absolute_pos_y(f),
243 Variable::AbsolutePosZ => D::compile_absolute_pos_z(f),
244 Variable::CubeCount(_) => D::compile_cube_count(f),
245 Variable::CubeCountBaseName => D::compile_cube_count_base_name(f),
246 Variable::CubeCountX => D::compile_cube_count_x(f),
247 Variable::CubeCountY => D::compile_cube_count_y(f),
248 Variable::CubeCountZ => D::compile_cube_count_z(f),
249 Variable::CubeDim => D::compile_cube_dim(f),
250 Variable::CubeDimBaseName => D::compile_cube_dim_base_name(f),
251 Variable::CubeDimX => D::compile_cube_dim_x(f),
252 Variable::CubeDimY => D::compile_cube_dim_y(f),
253 Variable::CubeDimZ => D::compile_cube_dim_z(f),
254 Variable::CubePos(_) => D::compile_cube_pos(f),
255 Variable::CubePosBaseName => D::compile_cube_pos_base_name(f),
256 Variable::CubePosX => D::compile_cube_pos_x(f),
257 Variable::CubePosY => D::compile_cube_pos_y(f),
258 Variable::CubePosZ => D::compile_cube_pos_z(f),
259 Variable::UnitPos => D::compile_unit_pos(f),
260 Variable::UnitPosBaseName => D::compile_unit_pos_base_name(f),
261 Variable::UnitPosX => D::compile_unit_pos_x(f),
262 Variable::UnitPosY => D::compile_unit_pos_y(f),
263 Variable::UnitPosZ => D::compile_unit_pos_z(f),
264 Variable::PlaneDim => D::compile_plane_dim(f),
265 Variable::PlaneDimChecked => D::compile_plane_dim_checked(f),
266 Variable::PlanePos => D::compile_plane_pos(f),
267 Variable::UnitPosPlane => D::compile_unit_pos_plane(f),
268 Variable::ClusterRank => D::compile_cluster_pos(f),
269 Variable::ClusterIndexX => D::compile_cluster_pos_x(f),
270 Variable::ClusterIndexY => D::compile_cluster_pos_y(f),
271 Variable::ClusterIndexZ => D::compile_cluster_pos_z(f),
272
273 Variable::ConstantArray(number, _, _) => f.write_fmt(format_args!("arrays_{number}")),
274 Variable::LocalArray(id, _, _) => {
275 write!(f, "l_arr_{id}")
276 }
277 Variable::WmmaFragment { id: index, frag } => {
278 let name = match frag.ident {
279 FragmentIdent::A => "a",
280 FragmentIdent::B => "b",
281 FragmentIdent::Accumulator => "acc",
282 FragmentIdent::_Dialect(_) => "",
283 };
284 write!(f, "frag_{name}_{index}")
285 }
286 Variable::Tmp { id, .. } => write!(f, "_tmp_{id}"),
287 Variable::Pipeline { id, .. } => write!(f, "pipeline_{id}"),
288 Variable::Barrier { id, .. } => write!(f, "barrier_{id}"),
289 Variable::BarrierToken { id, .. } => write!(f, "barrier_{id}_token"),
290 }
291 }
292}
293
294impl<D: Dialect> Variable<D> {
295 pub fn is_optimized(&self) -> bool {
296 self.item().is_optimized()
297 }
298
299 pub fn tmp(item: Item<D>) -> Self {
303 let inc = COUNTER_TMP_VAR.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
304
305 Variable::Tmp {
306 id: inc as Id,
307 item,
308 is_declared: false,
309 is_ptr: false,
310 is_const: false,
311 }
312 }
313
314 pub fn to_const(&mut self) {
315 if let Variable::Tmp { is_const, .. } = self {
316 *is_const = true;
317 }
318 }
319
320 pub fn reinterpret_ptr(&self, f: &mut Formatter<'_>, item: Item<D>) -> Self {
322 let mut out = Self::tmp_ptr(item);
323
324 if self.is_const() {
325 out.to_const();
326 }
327
328 let elem = out.elem();
329 let qualifier = out.const_qualifier();
330 let addr_space = D::address_space_for_variable(self);
331 let out_fmt = out.fmt_left();
332
333 writeln!(
334 f,
335 "{out_fmt} = reinterpret_cast<{addr_space}{elem}{qualifier}*>({self});"
336 )
337 .unwrap();
338
339 out
340 }
341
342 pub fn tmp_ptr(item: Item<D>) -> Self {
346 let inc = COUNTER_TMP_VAR.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
347
348 Variable::Tmp {
349 id: inc as Id,
350 item,
351 is_declared: false,
352 is_ptr: true,
353 is_const: false,
354 }
355 }
356
357 pub fn tmp_declared(item: Item<D>) -> Self {
363 let inc = COUNTER_TMP_VAR.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
364
365 Variable::Tmp {
366 id: inc as Id,
367 item,
368 is_declared: true,
369 is_ptr: false,
370 is_const: false,
371 }
372 }
373
374 pub fn optimized_args<const N: usize>(args: [Self; N]) -> OptimizedArgs<N, D> {
375 let args_after = args.map(|a| a.optimized());
376
377 let item_reference_after = args_after[0].item();
378
379 let is_optimized = args_after
380 .iter()
381 .all(|var| var.elem() == item_reference_after.elem && var.is_optimized());
382
383 if is_optimized {
384 let vectorization_before = args
385 .iter()
386 .map(|var| var.item().vectorization)
387 .max()
388 .unwrap();
389 let vectorization_after = args_after
390 .iter()
391 .map(|var| var.item().vectorization)
392 .max()
393 .unwrap();
394
395 OptimizedArgs::new(args_after, Some(vectorization_before / vectorization_after))
396 } else {
397 OptimizedArgs::new(args, None)
398 }
399 }
400
401 pub fn optimized(&self) -> Self {
402 match self {
403 Variable::GlobalInputArray(id, item) => {
404 Variable::GlobalInputArray(*id, item.optimized())
405 }
406 Variable::GlobalOutputArray(id, item) => {
407 Variable::GlobalOutputArray(*id, item.optimized())
408 }
409 Variable::LocalMut { id, item } => Variable::LocalMut {
410 id: *id,
411 item: item.optimized(),
412 },
413 Variable::LocalConst { id, item } => Variable::LocalConst {
414 id: *id,
415 item: item.optimized(),
416 },
417 Variable::Slice { id, item } => Variable::Slice {
418 id: *id,
419 item: item.optimized(),
420 },
421 Variable::Tmp {
422 id,
423 item,
424 is_declared,
425 is_ptr,
426 is_const,
427 } => Variable::Tmp {
428 id: *id,
429 item: item.optimized(),
430 is_declared: *is_declared,
431 is_ptr: *is_ptr,
432 is_const: *is_const,
433 },
434 Variable::SharedArray(id, item, size) => {
435 let before = item.vectorization;
436 let item = item.optimized();
437 let after = item.vectorization;
438 let scaling = before / after;
439
440 Variable::SharedArray(*id, item, size / scaling)
441 }
442 Variable::LocalArray(id, item, size) => {
443 let before = item.vectorization;
444 let item = item.optimized();
445 let after = item.vectorization;
446 let scaling = before / after;
447
448 Variable::LocalArray(*id, item.optimized(), size / scaling)
449 }
450 _ => *self,
451 }
452 }
453
454 pub fn is_always_scalar(&self) -> bool {
455 match self {
456 Variable::AbsolutePos(_) => true,
457 Variable::AbsolutePosBaseName => false,
458 Variable::AbsolutePosX => true,
459 Variable::AbsolutePosY => true,
460 Variable::AbsolutePosZ => true,
461 Variable::CubeCount(_) => true,
462 Variable::CubeCountBaseName => false,
463 Variable::CubeCountX => true,
464 Variable::CubeCountY => true,
465 Variable::CubeCountZ => true,
466 Variable::CubeDim => true,
467 Variable::CubeDimBaseName => false,
468 Variable::CubeDimX => true,
469 Variable::CubeDimY => true,
470 Variable::CubeDimZ => true,
471 Variable::CubePos(_) => true,
472 Variable::CubePosBaseName => true,
473 Variable::CubePosX => true,
474 Variable::CubePosY => true,
475 Variable::CubePosZ => true,
476 Variable::UnitPos => true,
477 Variable::UnitPosBaseName => true,
478 Variable::UnitPosPlane => true,
479 Variable::UnitPosX => true,
480 Variable::UnitPosY => true,
481 Variable::UnitPosZ => true,
482 Variable::PlaneDim => true,
483 Variable::PlaneDimChecked => true,
484 Variable::PlanePos => true,
485 Variable::ClusterRank => true,
486 Variable::ClusterIndexX => true,
487 Variable::ClusterIndexY => true,
488 Variable::ClusterIndexZ => true,
489
490 Variable::Barrier { .. } => false,
491 Variable::BarrierToken { .. } => false,
492 Variable::ConstantArray(_, _, _) => false,
493 Variable::Constant(_, _) => true,
494 Variable::GlobalInputArray(_, _) => false,
495 Variable::GlobalOutputArray(_, _) => false,
496 Variable::GlobalScalar { .. } => true,
497 Variable::LocalArray(_, _, _) => false,
498 Variable::LocalConst { .. } => false,
499 Variable::LocalMut { .. } => false,
500 Variable::Named { .. } => false,
501 Variable::Pipeline { .. } => false,
502 Variable::SharedArray(_, _, _) => false,
503 Variable::Shared(_, _) => false,
504 Variable::Slice { .. } => false,
505 Variable::Tmp { .. } => false,
506 Variable::WmmaFragment { .. } => false,
507 Variable::TensorMap { .. } => false,
508 }
509 }
510
511 pub fn index(&self, index: usize) -> IndexedVariable<D> {
512 IndexedVariable {
513 var: *self,
514 index,
515 optimized: self.is_optimized(),
516 }
517 }
518
519 pub fn const_qualifier(&self) -> &str {
520 if self.is_const() { " const" } else { "" }
521 }
522
523 pub fn id(&self) -> Option<Id> {
524 match self {
525 Variable::GlobalInputArray(id, ..) => Some(*id),
526 Variable::GlobalOutputArray(id, ..) => Some(*id),
527 Variable::GlobalScalar { id, .. } => Some(*id),
528 Variable::ConstantArray(id, ..) => Some(*id),
529 Variable::LocalMut { id, .. } => Some(*id),
530 Variable::LocalConst { id, .. } => Some(*id),
531 Variable::Slice { id, .. } => Some(*id),
532 Variable::Shared(id, ..) => Some(*id),
533 Variable::SharedArray(id, ..) => Some(*id),
534 Variable::LocalArray(id, ..) => Some(*id),
535 Variable::WmmaFragment { id, .. } => Some(*id),
536 Variable::Pipeline { id, .. } => Some(*id),
537 Variable::Barrier { id, .. } => Some(*id),
538 Variable::Tmp { id, .. } => Some(*id),
539 _ => None,
540 }
541 }
542
543 pub fn fmt_ptr(&self) -> String {
546 match self {
547 Variable::Slice { .. }
548 | Variable::SharedArray(_, _, _)
549 | Variable::GlobalInputArray(_, _)
550 | Variable::GlobalOutputArray(_, _) => format!("{self}"),
551 _ => format!("&{self}"),
552 }
553 }
554
555 pub fn fmt_cast_to(&self, item: Item<D>) -> String {
557 if self.item() == item {
558 self.to_string()
559 } else {
560 format!("{item}({self})")
561 }
562 }
563}
564
565impl<D: Dialect> FmtLeft for Variable<D> {
566 fn fmt_left(&self) -> String {
567 match self {
568 Self::LocalConst { item, .. } => match item.elem {
569 Elem::Atomic(_) => {
570 format!("{item}* {self}")
571 }
572 _ => {
573 format!("const {item} {self}")
574 }
575 },
576 Variable::Tmp {
577 item,
578 is_declared,
579 is_ptr,
580 is_const,
581 ..
582 } => {
583 if *is_declared {
584 return format!("{self}");
585 }
586 if *is_ptr {
587 if *is_const {
588 return format!("const {item} *{self}");
589 }
590 return format!("{item} *{self}");
591 }
592
593 format!("{item} {self}")
594 }
595 var => format!("{var}"),
596 }
597 }
598}
599
600#[derive(Debug, Clone)]
601pub struct IndexedVariable<D: Dialect> {
602 var: Variable<D>,
603 optimized: bool,
604 index: usize,
605}
606
607impl<D: Dialect> Component<D> for IndexedVariable<D> {
608 fn item(&self) -> Item<D> {
609 self.var.item()
610 }
611
612 fn index(&self, index: usize) -> IndexedVariable<D> {
613 self.var.index(index)
614 }
615
616 fn is_const(&self) -> bool {
617 self.var.is_const()
618 }
619}
620
621impl<D: Dialect> Display for IndexedVariable<D> {
622 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
623 let var = &self.var;
624
625 if let Variable::Constant(value, item) = var {
626 return write!(f, "{}({value})", item.elem());
627 }
628
629 let ref_ = matches!(var, Variable::LocalConst { .. })
630 .then_some("const&")
631 .unwrap_or("&");
632
633 if self.var.item().vectorization > 1 {
634 if self.optimized {
635 let item = self.var.item();
636 let addr_space = D::address_space_for_variable(&self.var);
637 write!(
638 f,
639 "(reinterpret_cast<{addr_space}{item} {ref_}>({var})).i_{}",
640 self.index
641 )
642 } else {
643 write!(f, "{var}.i_{}", self.index)
644 }
645 } else if self.optimized {
646 let item = self.var.item();
647 let addr_space = D::address_space_for_variable(&self.var);
648 write!(f, "reinterpret_cast<{addr_space}{item} {ref_}>({var})")
649 } else {
650 write!(f, "{var}")
651 }
652 }
653}
654
655impl<D: Dialect> FmtLeft for IndexedVariable<D> {
656 fn fmt_left(&self) -> String {
657 let var = &self.var;
658 let ref_ = matches!(var, Variable::LocalConst { .. })
659 .then_some("const&")
660 .unwrap_or("&");
661
662 let name = if self.var.item().vectorization > 1 {
663 if self.optimized {
664 let item = self.var.item();
665 let addr_space = D::address_space_for_variable(&self.var);
666 format!(
667 "(reinterpret_cast<{addr_space}{item} {ref_}>({var})).i_{}",
668 self.index
669 )
670 } else {
671 format!("{var}.i_{}", self.index)
672 }
673 } else {
674 format!("{var}")
675 };
676 match var {
677 Variable::LocalConst { item, .. } => format!("const {item} {name}"),
678 Variable::Tmp { item, is_ptr, .. } => {
679 if *is_ptr {
680 format!("{item} *{name}")
681 } else {
682 format!("{item} {name}")
683 }
684 }
685 _ => name,
686 }
687 }
688}
689
690impl FmtLeft for &String {
691 fn fmt_left(&self) -> String {
692 self.to_string()
693 }
694}