1use cubecl_core::{
2 ir::{self as gpu, ConstantScalarValue, Id},
3 tf32,
4};
5use half::{bf16, f16};
6use std::fmt::Display;
7
8use super::{Dialect, Fragment, FragmentIdent, COUNTER_TMP_VAR};
9
10#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash)]
11pub enum Elem<D: Dialect> {
12 TF32,
13 F32,
14 F64,
15 F16,
16 F162,
17 BF16,
18 BF162,
19 I8,
20 I16,
21 I32,
22 I64,
23 U8,
24 U16,
25 U32,
26 U64,
27 Bool,
28 Atomic(AtomicKind<D>),
29 _Dialect(std::marker::PhantomData<D>),
30}
31
32#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash)]
33pub enum AtomicKind<D: Dialect> {
34 I32,
35 I64,
36 U32,
37 U64,
38 F16,
39 BF16,
40 F32,
41 F64,
42 _Dialect(std::marker::PhantomData<D>),
44}
45
46impl<D: Dialect> Display for AtomicKind<D> {
47 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48 match self {
49 Self::I32 => Elem::<D>::I32.fmt(f),
50 Self::I64 => Elem::<D>::I64.fmt(f),
51 Self::U32 => Elem::<D>::U32.fmt(f),
52 Self::U64 => Elem::<D>::U64.fmt(f),
53 Self::F16 => Elem::<D>::F16.fmt(f),
54 Self::BF16 => Elem::<D>::BF16.fmt(f),
55 Self::F32 => Elem::<D>::F32.fmt(f),
56 Self::F64 => Elem::<D>::F64.fmt(f),
57 Self::_Dialect(_) => Ok(()),
58 }
59 }
60}
61
62#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash)]
63pub struct Item<D: Dialect> {
64 pub(crate) elem: Elem<D>,
65 pub(crate) vectorization: usize,
66}
67
68impl<D: Dialect> Display for Elem<D> {
69 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
70 match self {
71 Elem::F16 => f.write_str("__half"),
72 Elem::F162 => f.write_str("__half2"),
73 Elem::F32 => f.write_str("float"),
74 Elem::F64 => f.write_str("double"),
75 Elem::BF16 => D::bfloat16_type_name(f),
76 Elem::BF162 => D::bfloat162_type_name(f),
77 Elem::TF32 => f.write_str("float"),
78 Elem::I8 => f.write_str("char"),
79 Elem::I16 => f.write_str("short"),
80 Elem::I32 => f.write_str("int"),
81 Elem::I64 => f.write_str("int64"),
82 Elem::U8 => f.write_str("uint8"),
83 Elem::U16 => f.write_str("uint16"),
84 Elem::U32 => f.write_str("uint"),
85 Elem::U64 => f.write_str("uint64"),
86 Elem::Bool => f.write_str("bool"),
87 Elem::Atomic(inner) => inner.fmt(f),
88 Elem::_Dialect(_) => Ok(()),
89 }
90 }
91}
92
93impl<D: Dialect> Display for Item<D> {
94 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95 if 1 == self.vectorization {
96 return write!(f, "{}", self.elem);
97 }
98
99 write!(f, "{}_{}", self.elem, self.vectorization)
100 }
101}
102
103pub trait Component<D: Dialect>: Display + FmtLeft {
104 fn item(&self) -> Item<D>;
105 fn is_const(&self) -> bool;
106 fn index(&self, index: usize) -> IndexedVariable<D>;
107 fn elem(&self) -> Elem<D> {
108 *self.item().elem()
109 }
110}
111
112impl<D: Dialect> Component<D> for IndexedVariable<D> {
113 fn item(&self) -> Item<D> {
114 self.var.item()
115 }
116
117 fn index(&self, index: usize) -> IndexedVariable<D> {
118 self.var.index(index)
119 }
120
121 fn is_const(&self) -> bool {
122 matches!(self.var, Variable::LocalConst { .. })
123 }
124}
125
126impl<D: Dialect> Component<D> for Variable<D> {
127 fn index(&self, index: usize) -> IndexedVariable<D> {
128 self.index(index)
129 }
130
131 fn item(&self) -> Item<D> {
132 match self {
133 Variable::GlobalInputArray(_, e) => *e,
134 Variable::GlobalOutputArray(_, e) => *e,
135 Variable::SharedMemory(_, e, _) => *e,
136 Variable::ConstantArray(_, e, _) => *e,
137 Variable::LocalMut { item, .. } => *item,
138 Variable::LocalConst { item, .. } => *item,
139 Variable::Named { item, .. } => *item,
140 Variable::Slice { item, .. } => *item,
141 Variable::ConstantScalar(_, e) => Item::scalar(*e),
142 Variable::GlobalScalar(_, e, _) => Item::scalar(*e),
143 Variable::IdxGlobal => Item::scalar(Elem::U32),
144 Variable::ThreadIdxGlobal => Item::scalar(Elem::U32),
145 Variable::ThreadIdxX => Item::scalar(Elem::U32),
146 Variable::ThreadIdxY => Item::scalar(Elem::U32),
147 Variable::ThreadIdxZ => Item::scalar(Elem::U32),
148 Variable::BlockIdxX => Item::scalar(Elem::U32),
149 Variable::BlockIdxY => Item::scalar(Elem::U32),
150 Variable::BlockIdxZ => Item::scalar(Elem::U32),
151 Variable::AbsoluteIdxX => Item::scalar(Elem::U32),
152 Variable::AbsoluteIdxY => Item::scalar(Elem::U32),
153 Variable::AbsoluteIdxZ => Item::scalar(Elem::U32),
154 Variable::BlockDimX => Item::scalar(Elem::U32),
155 Variable::BlockDimY => Item::scalar(Elem::U32),
156 Variable::BlockDimZ => Item::scalar(Elem::U32),
157 Variable::GridDimX => Item::scalar(Elem::U32),
158 Variable::GridDimY => Item::scalar(Elem::U32),
159 Variable::GridDimZ => Item::scalar(Elem::U32),
160 Variable::LocalArray(_, e, _) => *e,
161 Variable::WarpSize => Item::scalar(Elem::U32),
162 Variable::ThreadIdxWarp => Item::scalar(Elem::U32),
163 Variable::WmmaFragment { frag, .. } => Item::scalar(frag.elem),
164 Variable::BlockIdxGlobal => Item::scalar(Elem::U32),
165 Variable::BlockDimGlobal => Item::scalar(Elem::U32),
166 Variable::GridDimGlobal => Item::scalar(Elem::U32),
167 Variable::Tmp { item, .. } => *item,
168 }
169 }
170
171 fn is_const(&self) -> bool {
172 matches!(self, Variable::LocalConst { .. })
173 }
174}
175
176#[derive(Debug, Clone, Copy, PartialEq)]
177pub enum Variable<D: Dialect> {
178 WarpSize,
179 ThreadIdxWarp,
180 GlobalInputArray(Id, Item<D>),
181 GlobalOutputArray(Id, Item<D>),
182 GlobalScalar(Id, Elem<D>, gpu::Elem),
183 ConstantArray(Id, Item<D>, u32),
184 ConstantScalar(ConstantScalarValue, Elem<D>),
185 LocalMut { id: Id, item: Item<D> },
186 LocalConst { id: Id, item: Item<D> },
187 Named { name: &'static str, item: Item<D> },
188 Slice { id: Id, item: Item<D> },
189 SharedMemory(Id, Item<D>, u32),
190 LocalArray(Id, Item<D>, u32),
191 IdxGlobal,
192 ThreadIdxGlobal,
193 ThreadIdxX,
194 ThreadIdxY,
195 ThreadIdxZ,
196 BlockIdxGlobal,
197 BlockIdxX,
198 BlockIdxY,
199 BlockIdxZ,
200 AbsoluteIdxX,
201 AbsoluteIdxY,
202 AbsoluteIdxZ,
203 BlockDimGlobal,
204 BlockDimX,
205 BlockDimY,
206 BlockDimZ,
207 GridDimGlobal,
208 GridDimX,
209 GridDimY,
210 GridDimZ,
211 WmmaFragment { id: Id, frag: Fragment<D> },
212 Tmp { id: Id, item: Item<D> },
213}
214
215impl<D: Dialect> Display for Variable<D> {
216 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
217 match self {
218 Variable::GlobalInputArray(number, _) => f.write_fmt(format_args!("input_{number}")),
219 Variable::LocalMut { id, .. } => f.write_fmt(format_args!("l_mut_{id}")),
220 Variable::LocalConst { id, .. } => f.write_fmt(format_args!("l_{id}")),
221 Variable::Named { name, .. } => f.write_fmt(format_args!("{name}")),
222 Variable::Slice { id, .. } => {
223 write!(f, "slice_{id}")
224 }
225 Variable::GlobalOutputArray(number, _) => write!(f, "output_{number}"),
226 Variable::GlobalScalar(number, _, elem) => {
227 write!(f, "scalars_{elem}[{number}]")
228 }
229 Variable::ConstantScalar(number, elem) => match number {
232 ConstantScalarValue::Int(val, kind) => match kind {
233 gpu::IntKind::I8 => write!(f, "{elem}({})", *val as i8),
234 gpu::IntKind::I16 => write!(f, "{elem}({})", *val as i16),
235 gpu::IntKind::I32 => write!(f, "{elem}({})", *val as i32),
236 gpu::IntKind::I64 => write!(f, "{elem}({})", *val),
237 },
238 ConstantScalarValue::Float(val, kind) => match kind {
239 gpu::FloatKind::F16 => {
240 write!(f, "{elem}({:?})", half::f16::from_f64(*val))
241 }
242 gpu::FloatKind::BF16 => {
243 write!(f, "{elem}({:?})", half::bf16::from_f64(*val))
244 }
245 gpu::FloatKind::Flex32 => write!(f, "{elem}({:?})", *val as f32),
246 gpu::FloatKind::TF32 => write!(f, "{elem}({:?})", *val as f32),
247 gpu::FloatKind::F32 => write!(f, "{elem}({:?})", *val as f32),
248 gpu::FloatKind::F64 => write!(f, "{elem}({:?})", *val),
249 },
250 ConstantScalarValue::UInt(val, kind) => match kind {
251 gpu::UIntKind::U8 => write!(f, "{elem}({})", *val as u8),
252 gpu::UIntKind::U16 => write!(f, "{elem}({})", *val as u16),
253 gpu::UIntKind::U32 => write!(f, "{elem}({})", *val as u32),
254 gpu::UIntKind::U64 => write!(f, "{elem}({})", *val),
255 },
256 ConstantScalarValue::Bool(val) => write!(f, "{}", val),
257 },
258 Variable::SharedMemory(number, _, _) => {
259 write!(f, "shared_memory_{number}")
260 }
261 Variable::ConstantArray(number, _, _) => f.write_fmt(format_args!("arrays_{number}")),
262 Variable::ThreadIdxGlobal => f.write_str("threadIdxGlobal"),
263 Variable::ThreadIdxX => f.write_str("threadIdx.x"),
264 Variable::ThreadIdxY => f.write_str("threadIdx.y"),
265 Variable::ThreadIdxZ => f.write_str("threadIdx.z"),
266 Variable::BlockIdxGlobal => f.write_str("blockIdxGlobal"),
267 Variable::BlockIdxX => f.write_str("blockIdx.x"),
268 Variable::BlockIdxY => f.write_str("blockIdx.y"),
269 Variable::BlockIdxZ => f.write_str("blockIdx.z"),
270 Variable::BlockDimGlobal => f.write_str("blockDimGlobal"),
271 Variable::BlockDimX => f.write_str("blockDim.x"),
272 Variable::BlockDimY => f.write_str("blockDim.y"),
273 Variable::BlockDimZ => f.write_str("blockDim.z"),
274 Variable::IdxGlobal => f.write_str("idxGlobal"),
275 Variable::GridDimX => f.write_str("gridDim.x"),
276 Variable::GridDimY => f.write_str("gridDim.y"),
277 Variable::GridDimZ => f.write_str("gridDim.z"),
278 Variable::AbsoluteIdxX => f.write_str("absoluteIdx.x"),
279 Variable::AbsoluteIdxY => f.write_str("absoluteIdx.y"),
280 Variable::AbsoluteIdxZ => f.write_str("absoluteIdx.z"),
281 Variable::LocalArray(id, _, _) => {
282 write!(f, "l_arr_{}", id)
283 }
284 Variable::WarpSize => f.write_str("warpSize"),
285 Variable::ThreadIdxWarp => f.write_str("threadIdxGlobal % warpSize"),
286 Variable::WmmaFragment { id: index, frag } => {
287 let name = match frag.ident {
288 FragmentIdent::A => "a",
289 FragmentIdent::B => "b",
290 FragmentIdent::Accumulator => "acc",
291 FragmentIdent::_Dialect(_) => "",
292 };
293 write!(f, "frag_{name}_{index}")
294 }
295 Variable::GridDimGlobal => f.write_str("gridDimGlobal"),
296 Self::Tmp { id, .. } => write!(f, "_tmp_{id}"),
297 }
298 }
299}
300
301#[derive(new)]
302pub struct OptimizedArgs<const N: usize, D: Dialect> {
303 pub args: [Variable<D>; N],
304 pub optimization_factor: Option<usize>,
305}
306
307impl<D: Dialect> Variable<D> {
308 pub fn is_optimized(&self) -> bool {
309 self.item().is_optimized()
310 }
311
312 pub fn tmp(item: Item<D>) -> Self {
313 let inc = COUNTER_TMP_VAR.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
314
315 Variable::Tmp {
316 id: inc as Id,
317 item,
318 }
319 }
320
321 pub fn optimized_args<const N: usize>(args: [Self; N]) -> OptimizedArgs<N, D> {
322 let args_after = args.map(|a| a.optimized());
323
324 let item_reference_after = args_after[0].item();
325
326 let is_optimized = args_after
327 .iter()
328 .all(|var| var.elem() == item_reference_after.elem && var.is_optimized());
329
330 if is_optimized {
331 let vectorization_before = args
332 .iter()
333 .map(|var| var.item().vectorization)
334 .max()
335 .unwrap();
336 let vectorization_after = args_after
337 .iter()
338 .map(|var| var.item().vectorization)
339 .max()
340 .unwrap();
341
342 OptimizedArgs::new(args_after, Some(vectorization_before / vectorization_after))
343 } else {
344 OptimizedArgs::new(args, None)
345 }
346 }
347
348 pub fn optimized(&self) -> Self {
349 match self {
350 Variable::GlobalInputArray(id, item) => {
351 Variable::GlobalInputArray(*id, item.optimized())
352 }
353 Variable::GlobalOutputArray(id, item) => {
354 Variable::GlobalOutputArray(*id, item.optimized())
355 }
356 Variable::LocalMut { id, item } => Variable::LocalMut {
357 id: *id,
358 item: item.optimized(),
359 },
360 Variable::LocalConst { id, item } => Variable::LocalConst {
361 id: *id,
362 item: item.optimized(),
363 },
364 Variable::Slice { id, item } => Variable::Slice {
365 id: *id,
366 item: item.optimized(),
367 },
368 Variable::SharedMemory(id, item, size) => {
369 let before = item.vectorization;
370 let item = item.optimized();
371 let after = item.vectorization;
372 let scaling = (before / after) as u32;
373
374 Variable::SharedMemory(*id, item, size / scaling)
375 }
376 Variable::LocalArray(id, item, size) => {
377 let before = item.vectorization;
378 let item = item.optimized();
379 let after = item.vectorization;
380 let scaling = (before / after) as u32;
381
382 Variable::LocalArray(*id, item.optimized(), size / scaling)
383 }
384 _ => *self,
385 }
386 }
387
388 pub fn is_always_scalar(&self) -> bool {
389 match self {
390 Variable::GlobalScalar(_, _, _) => true,
391 Variable::ConstantScalar(_, _) => true,
392 Variable::IdxGlobal => true,
393 Variable::ThreadIdxGlobal => true,
394 Variable::ThreadIdxX => true,
395 Variable::ThreadIdxY => true,
396 Variable::ThreadIdxZ => true,
397 Variable::GlobalInputArray(_, _) => false,
398 Variable::GlobalOutputArray(_, _) => false,
399 Variable::SharedMemory(_, _, _) => false,
400 Variable::ConstantArray(_, _, _) => false,
401 Variable::LocalMut { .. } => false,
402 Variable::LocalConst { .. } => false,
403 Variable::Named { .. } => false,
404 Variable::Slice { .. } => false,
405 Variable::BlockIdxX => true,
406 Variable::BlockIdxY => true,
407 Variable::BlockIdxZ => true,
408 Variable::AbsoluteIdxX => true,
409 Variable::AbsoluteIdxY => true,
410 Variable::AbsoluteIdxZ => true,
411 Variable::BlockDimX => true,
412 Variable::BlockDimY => true,
413 Variable::BlockDimZ => true,
414 Variable::GridDimX => true,
415 Variable::GridDimY => true,
416 Variable::GridDimZ => true,
417 Variable::LocalArray(_, _, _) => false,
418 Variable::WarpSize => true,
419 Variable::ThreadIdxWarp => true,
420 Variable::WmmaFragment { .. } => false,
421 Variable::BlockIdxGlobal => true,
422 Variable::BlockDimGlobal => true,
423 Variable::GridDimGlobal => true,
424 Variable::Tmp { .. } => false,
425 }
426 }
427
428 pub fn index(&self, index: usize) -> IndexedVariable<D> {
429 IndexedVariable {
430 var: *self,
431 index,
432 optimized: self.is_optimized(),
433 }
434 }
435
436 pub fn const_qualifier(&self) -> &str {
437 if self.is_const() {
438 " const"
439 } else {
440 ""
441 }
442 }
443}
444
/// Formatting of a value as the left-hand side of an assignment.
pub trait FmtLeft: Display {
    /// Returns the text to emit when the value is the assignment target.
    fn fmt_left(&self) -> String;
}
448
449impl<D: Dialect> FmtLeft for Variable<D> {
450 fn fmt_left(&self) -> String {
451 match self {
452 Self::LocalConst { item, .. } => format!("const {item} {self}"),
453 Variable::Tmp { item, .. } => format!("{item} {self}"),
454 var => format!("{var}"),
455 }
456 }
457}
458
459impl<D: Dialect> FmtLeft for IndexedVariable<D> {
460 fn fmt_left(&self) -> String {
461 match self.var {
462 Variable::LocalConst { item, .. } => format!("const {item} {self}"),
463 Variable::Tmp { item, .. } => format!("{item} {self}"),
464 _ => format!("{self}"),
465 }
466 }
467}
468
469impl FmtLeft for &String {
470 fn fmt_left(&self) -> String {
471 self.to_string()
472 }
473}
474
475#[derive(Debug, Clone)]
476pub struct IndexedVariable<D: Dialect> {
477 var: Variable<D>,
478 optimized: bool,
479 index: usize,
480}
481
482impl<D: Dialect> Display for IndexedVariable<D> {
483 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
484 let var = &self.var;
485 let ref_ = matches!(var, Variable::LocalConst { .. })
486 .then_some("const&")
487 .unwrap_or("&");
488
489 if self.var.item().vectorization > 1 {
490 if self.optimized {
491 let item = self.var.item();
492 write!(
493 f,
494 "(reinterpret_cast<{item} {ref_}>({var})).i_{}",
495 self.index
496 )
497 } else {
498 write!(f, "{var}.i_{}", self.index)
499 }
500 } else if self.optimized {
501 let item = self.var.item();
502 write!(f, "reinterpret_cast<{item} {ref_}>({var})")
503 } else {
504 write!(f, "{var}")
505 }
506 }
507}
508impl<D: Dialect> Item<D> {
509 pub fn elem(&self) -> &Elem<D> {
510 &self.elem
511 }
512
513 pub fn de_optimized(&self) -> Self {
514 match self.elem {
515 Elem::F162 => Item::new(Elem::F16, self.vectorization * 2),
516 Elem::BF162 => Item::new(Elem::BF16, self.vectorization * 2),
517 _ => *self,
518 }
519 }
520
521 pub fn new(elem: Elem<D>, vectorization: usize) -> Self {
522 Self {
523 elem,
524 vectorization,
525 }
526 }
527 pub fn scalar(elem: Elem<D>) -> Self {
528 Self {
529 elem,
530 vectorization: 1,
531 }
532 }
533
534 pub fn is_optimized(&self) -> bool {
535 matches!(self.elem, Elem::F162 | Elem::BF162)
536 }
537
538 pub fn optimized(&self) -> Item<D> {
539 if self.vectorization == 1 {
540 return *self;
541 }
542
543 if self.vectorization % 2 != 0 {
544 return *self;
545 }
546
547 match self.elem {
548 Elem::F16 => Item {
549 elem: Elem::F162,
550 vectorization: self.vectorization / 2,
551 },
552 Elem::BF16 => Item {
553 elem: Elem::BF162,
554 vectorization: self.vectorization / 2,
555 },
556 _ => *self,
557 }
558 }
559}
560
561impl<D: Dialect> Elem<D> {
562 pub const fn size(&self) -> usize {
563 match self {
564 Elem::F16 => core::mem::size_of::<f16>(),
565 Elem::F162 => 2 * core::mem::size_of::<f16>(),
566 Elem::BF162 => 2 * core::mem::size_of::<bf16>(),
567 Elem::BF16 => core::mem::size_of::<bf16>(),
568 Elem::TF32 => core::mem::size_of::<tf32>(),
569 Elem::F32 => core::mem::size_of::<f32>(),
570 Elem::F64 => core::mem::size_of::<f64>(),
571 Elem::I8 => core::mem::size_of::<i8>(),
572 Elem::I16 => core::mem::size_of::<i16>(),
573 Elem::I32 => core::mem::size_of::<i32>(),
574 Elem::I64 => core::mem::size_of::<i64>(),
575 Elem::U8 => core::mem::size_of::<u8>(),
576 Elem::U16 => core::mem::size_of::<u16>(),
577 Elem::U32 => core::mem::size_of::<u32>(),
578 Elem::U64 => core::mem::size_of::<u64>(),
579 Elem::Bool => core::mem::size_of::<bool>(),
580 Elem::Atomic(AtomicKind::I32) => core::mem::size_of::<i32>(),
581 Elem::Atomic(AtomicKind::I64) => core::mem::size_of::<i64>(),
582 Elem::Atomic(AtomicKind::U32) => core::mem::size_of::<u32>(),
583 Elem::Atomic(AtomicKind::U64) => core::mem::size_of::<u64>(),
584 Elem::Atomic(AtomicKind::F16) => core::mem::size_of::<f16>(),
585 Elem::Atomic(AtomicKind::BF16) => core::mem::size_of::<bf16>(),
586 Elem::Atomic(AtomicKind::F32) => core::mem::size_of::<f32>(),
587 Elem::Atomic(AtomicKind::F64) => core::mem::size_of::<f64>(),
588 Elem::Atomic(AtomicKind::_Dialect(_)) => 0,
589 Elem::_Dialect(_) => 0,
590 }
591 }
592}