cubecl_cpp/shared/
element.rs

1use cubecl_core::{
2    ir::{self as gpu, ConstantScalarValue, Id},
3    tf32,
4};
5use half::{bf16, f16};
6use std::fmt::Display;
7
8use super::{Dialect, Fragment, FragmentIdent, COUNTER_TMP_VAR};
9
/// Element type of a value as it is spelled in generated C++ source.
///
/// `F162` / `BF162` are packed two-lane half types: their `size` is twice that of
/// `F16` / `BF16`, and `Item::optimized` produces them from even-width `F16` / `BF16`
/// vectors. The bfloat16 spellings are delegated to the dialect `D`.
#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash)]
pub enum Elem<D: Dialect> {
    /// TensorFloat-32; rendered as plain `float`.
    TF32,
    F32,
    F64,
    F16,
    /// Two packed `F16` lanes (`__half2`).
    F162,
    BF16,
    /// Two packed `BF16` lanes (name supplied by the dialect).
    BF162,
    I8,
    I16,
    I32,
    I64,
    U8,
    U16,
    U32,
    U64,
    Bool,
    /// Atomic wrapper around one of the supported scalar kinds.
    Atomic(AtomicKind<D>),
    /// Only carries the `D` type parameter: renders as nothing and has a size of 0.
    _Dialect(std::marker::PhantomData<D>),
}
31
/// Scalar kinds that may be wrapped in `Elem::Atomic`.
///
/// Each kind is rendered exactly like the matching non-atomic `Elem` variant.
#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash)]
pub enum AtomicKind<D: Dialect> {
    I32,
    I64,
    U32,
    U64,
    F16,
    BF16,
    F32,
    F64,
    /// Required to construct the inner `Elem` of the atomic value
    _Dialect(std::marker::PhantomData<D>),
}
45
46impl<D: Dialect> Display for AtomicKind<D> {
47    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48        match self {
49            Self::I32 => Elem::<D>::I32.fmt(f),
50            Self::I64 => Elem::<D>::I64.fmt(f),
51            Self::U32 => Elem::<D>::U32.fmt(f),
52            Self::U64 => Elem::<D>::U64.fmt(f),
53            Self::F16 => Elem::<D>::F16.fmt(f),
54            Self::BF16 => Elem::<D>::BF16.fmt(f),
55            Self::F32 => Elem::<D>::F32.fmt(f),
56            Self::F64 => Elem::<D>::F64.fmt(f),
57            Self::_Dialect(_) => Ok(()),
58        }
59    }
60}
61
/// An element type together with a vector width.
///
/// Rendered as `{elem}` when `vectorization == 1`, otherwise as `{elem}_{vectorization}`.
#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash)]
pub struct Item<D: Dialect> {
    /// Element type of every lane.
    pub(crate) elem: Elem<D>,
    /// Number of lanes (1 for a scalar).
    pub(crate) vectorization: usize,
}
67
68impl<D: Dialect> Display for Elem<D> {
69    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
70        match self {
71            Elem::F16 => f.write_str("__half"),
72            Elem::F162 => f.write_str("__half2"),
73            Elem::F32 => f.write_str("float"),
74            Elem::F64 => f.write_str("double"),
75            Elem::BF16 => D::bfloat16_type_name(f),
76            Elem::BF162 => D::bfloat162_type_name(f),
77            Elem::TF32 => f.write_str("float"),
78            Elem::I8 => f.write_str("char"),
79            Elem::I16 => f.write_str("short"),
80            Elem::I32 => f.write_str("int"),
81            Elem::I64 => f.write_str("int64"),
82            Elem::U8 => f.write_str("uint8"),
83            Elem::U16 => f.write_str("uint16"),
84            Elem::U32 => f.write_str("uint"),
85            Elem::U64 => f.write_str("uint64"),
86            Elem::Bool => f.write_str("bool"),
87            Elem::Atomic(inner) => inner.fmt(f),
88            Elem::_Dialect(_) => Ok(()),
89        }
90    }
91}
92
93impl<D: Dialect> Display for Item<D> {
94    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95        if 1 == self.vectorization {
96            return write!(f, "{}", self.elem);
97        }
98
99        write!(f, "{}_{}", self.elem, self.vectorization)
100    }
101}
102
103pub trait Component<D: Dialect>: Display + FmtLeft {
104    fn item(&self) -> Item<D>;
105    fn is_const(&self) -> bool;
106    fn index(&self, index: usize) -> IndexedVariable<D>;
107    fn elem(&self) -> Elem<D> {
108        *self.item().elem()
109    }
110}
111
112impl<D: Dialect> Component<D> for IndexedVariable<D> {
113    fn item(&self) -> Item<D> {
114        self.var.item()
115    }
116
117    fn index(&self, index: usize) -> IndexedVariable<D> {
118        self.var.index(index)
119    }
120
121    fn is_const(&self) -> bool {
122        matches!(self.var, Variable::LocalConst { .. })
123    }
124}
125
126impl<D: Dialect> Component<D> for Variable<D> {
127    fn index(&self, index: usize) -> IndexedVariable<D> {
128        self.index(index)
129    }
130
131    fn item(&self) -> Item<D> {
132        match self {
133            Variable::GlobalInputArray(_, e) => *e,
134            Variable::GlobalOutputArray(_, e) => *e,
135            Variable::SharedMemory(_, e, _) => *e,
136            Variable::ConstantArray(_, e, _) => *e,
137            Variable::LocalMut { item, .. } => *item,
138            Variable::LocalConst { item, .. } => *item,
139            Variable::Named { item, .. } => *item,
140            Variable::Slice { item, .. } => *item,
141            Variable::ConstantScalar(_, e) => Item::scalar(*e),
142            Variable::GlobalScalar(_, e, _) => Item::scalar(*e),
143            Variable::IdxGlobal => Item::scalar(Elem::U32),
144            Variable::ThreadIdxGlobal => Item::scalar(Elem::U32),
145            Variable::ThreadIdxX => Item::scalar(Elem::U32),
146            Variable::ThreadIdxY => Item::scalar(Elem::U32),
147            Variable::ThreadIdxZ => Item::scalar(Elem::U32),
148            Variable::BlockIdxX => Item::scalar(Elem::U32),
149            Variable::BlockIdxY => Item::scalar(Elem::U32),
150            Variable::BlockIdxZ => Item::scalar(Elem::U32),
151            Variable::AbsoluteIdxX => Item::scalar(Elem::U32),
152            Variable::AbsoluteIdxY => Item::scalar(Elem::U32),
153            Variable::AbsoluteIdxZ => Item::scalar(Elem::U32),
154            Variable::BlockDimX => Item::scalar(Elem::U32),
155            Variable::BlockDimY => Item::scalar(Elem::U32),
156            Variable::BlockDimZ => Item::scalar(Elem::U32),
157            Variable::GridDimX => Item::scalar(Elem::U32),
158            Variable::GridDimY => Item::scalar(Elem::U32),
159            Variable::GridDimZ => Item::scalar(Elem::U32),
160            Variable::LocalArray(_, e, _) => *e,
161            Variable::WarpSize => Item::scalar(Elem::U32),
162            Variable::ThreadIdxWarp => Item::scalar(Elem::U32),
163            Variable::WmmaFragment { frag, .. } => Item::scalar(frag.elem),
164            Variable::BlockIdxGlobal => Item::scalar(Elem::U32),
165            Variable::BlockDimGlobal => Item::scalar(Elem::U32),
166            Variable::GridDimGlobal => Item::scalar(Elem::U32),
167            Variable::Tmp { item, .. } => *item,
168        }
169    }
170
171    fn is_const(&self) -> bool {
172        matches!(self, Variable::LocalConst { .. })
173    }
174}
175
/// An operand that can be referenced in generated C++ code.
///
/// Its `Display` impl renders the C++ expression naming it (e.g. `input_{id}`,
/// `l_mut_{id}`, `threadIdx.x`).
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Variable<D: Dialect> {
    /// The hardware warp size (`warpSize`).
    WarpSize,
    /// Lane of the thread within its warp, derived from `threadIdxGlobal` and `warpSize`.
    ThreadIdxWarp,
    /// Global input buffer `input_{id}`.
    GlobalInputArray(Id, Item<D>),
    /// Global output buffer `output_{id}`.
    GlobalOutputArray(Id, Item<D>),
    /// Entry `id` of the `scalars_{gpu elem}` array; the `Elem<D>` is its compiled element type.
    GlobalScalar(Id, Elem<D>, gpu::Elem),
    /// Constant array `arrays_{id}`; the `u32` is presumably its length — TODO confirm.
    ConstantArray(Id, Item<D>, u32),
    /// Literal value, rendered as `{elem}({value})` (booleans as bare `true`/`false`).
    ConstantScalar(ConstantScalarValue, Elem<D>),
    /// Mutable local `l_mut_{id}`.
    LocalMut { id: Id, item: Item<D> },
    /// Constant local `l_{id}`; its left-hand form (`fmt_left`) carries a `const {item}` prefix.
    LocalConst { id: Id, item: Item<D> },
    /// A variable referenced by a fixed, caller-chosen name.
    Named { name: &'static str, item: Item<D> },
    /// Slice `slice_{id}`.
    Slice { id: Id, item: Item<D> },
    /// Shared-memory buffer `shared_memory_{id}`; `optimized` rescales the `u32` size
    /// when it packs the item.
    SharedMemory(Id, Item<D>, u32),
    /// Local array `l_arr_{id}`; `optimized` rescales the `u32` size when it packs the item.
    LocalArray(Id, Item<D>, u32),
    // Built-in thread/block/grid indices and dimensions follow. Each renders as a
    // CUDA-style builtin (`threadIdx.x`, `blockDim.y`, ...) or a precomputed name
    // (`idxGlobal`, `absoluteIdx.x`, ...), and all have item type `U32`.
    IdxGlobal,
    ThreadIdxGlobal,
    ThreadIdxX,
    ThreadIdxY,
    ThreadIdxZ,
    BlockIdxGlobal,
    BlockIdxX,
    BlockIdxY,
    BlockIdxZ,
    AbsoluteIdxX,
    AbsoluteIdxY,
    AbsoluteIdxZ,
    BlockDimGlobal,
    BlockDimX,
    BlockDimY,
    BlockDimZ,
    GridDimGlobal,
    GridDimX,
    GridDimY,
    GridDimZ,
    /// WMMA fragment, rendered as `frag_{a|b|acc}_{id}`.
    WmmaFragment { id: Id, frag: Fragment<D> },
    /// Compiler-introduced temporary `_tmp_{id}`, allocated by `Variable::tmp`.
    Tmp { id: Id, item: Item<D> },
}
214
215impl<D: Dialect> Display for Variable<D> {
216    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
217        match self {
218            Variable::GlobalInputArray(number, _) => f.write_fmt(format_args!("input_{number}")),
219            Variable::LocalMut { id, .. } => f.write_fmt(format_args!("l_mut_{id}")),
220            Variable::LocalConst { id, .. } => f.write_fmt(format_args!("l_{id}")),
221            Variable::Named { name, .. } => f.write_fmt(format_args!("{name}")),
222            Variable::Slice { id, .. } => {
223                write!(f, "slice_{id}")
224            }
225            Variable::GlobalOutputArray(number, _) => write!(f, "output_{number}"),
226            Variable::GlobalScalar(number, _, elem) => {
227                write!(f, "scalars_{elem}[{number}]")
228            }
229            // We do the conversion in Rust and then render the number to avoid overflow or other
230            // precision related problems.
231            Variable::ConstantScalar(number, elem) => match number {
232                ConstantScalarValue::Int(val, kind) => match kind {
233                    gpu::IntKind::I8 => write!(f, "{elem}({})", *val as i8),
234                    gpu::IntKind::I16 => write!(f, "{elem}({})", *val as i16),
235                    gpu::IntKind::I32 => write!(f, "{elem}({})", *val as i32),
236                    gpu::IntKind::I64 => write!(f, "{elem}({})", *val),
237                },
238                ConstantScalarValue::Float(val, kind) => match kind {
239                    gpu::FloatKind::F16 => {
240                        write!(f, "{elem}({:?})", half::f16::from_f64(*val))
241                    }
242                    gpu::FloatKind::BF16 => {
243                        write!(f, "{elem}({:?})", half::bf16::from_f64(*val))
244                    }
245                    gpu::FloatKind::Flex32 => write!(f, "{elem}({:?})", *val as f32),
246                    gpu::FloatKind::TF32 => write!(f, "{elem}({:?})", *val as f32),
247                    gpu::FloatKind::F32 => write!(f, "{elem}({:?})", *val as f32),
248                    gpu::FloatKind::F64 => write!(f, "{elem}({:?})", *val),
249                },
250                ConstantScalarValue::UInt(val, kind) => match kind {
251                    gpu::UIntKind::U8 => write!(f, "{elem}({})", *val as u8),
252                    gpu::UIntKind::U16 => write!(f, "{elem}({})", *val as u16),
253                    gpu::UIntKind::U32 => write!(f, "{elem}({})", *val as u32),
254                    gpu::UIntKind::U64 => write!(f, "{elem}({})", *val),
255                },
256                ConstantScalarValue::Bool(val) => write!(f, "{}", val),
257            },
258            Variable::SharedMemory(number, _, _) => {
259                write!(f, "shared_memory_{number}")
260            }
261            Variable::ConstantArray(number, _, _) => f.write_fmt(format_args!("arrays_{number}")),
262            Variable::ThreadIdxGlobal => f.write_str("threadIdxGlobal"),
263            Variable::ThreadIdxX => f.write_str("threadIdx.x"),
264            Variable::ThreadIdxY => f.write_str("threadIdx.y"),
265            Variable::ThreadIdxZ => f.write_str("threadIdx.z"),
266            Variable::BlockIdxGlobal => f.write_str("blockIdxGlobal"),
267            Variable::BlockIdxX => f.write_str("blockIdx.x"),
268            Variable::BlockIdxY => f.write_str("blockIdx.y"),
269            Variable::BlockIdxZ => f.write_str("blockIdx.z"),
270            Variable::BlockDimGlobal => f.write_str("blockDimGlobal"),
271            Variable::BlockDimX => f.write_str("blockDim.x"),
272            Variable::BlockDimY => f.write_str("blockDim.y"),
273            Variable::BlockDimZ => f.write_str("blockDim.z"),
274            Variable::IdxGlobal => f.write_str("idxGlobal"),
275            Variable::GridDimX => f.write_str("gridDim.x"),
276            Variable::GridDimY => f.write_str("gridDim.y"),
277            Variable::GridDimZ => f.write_str("gridDim.z"),
278            Variable::AbsoluteIdxX => f.write_str("absoluteIdx.x"),
279            Variable::AbsoluteIdxY => f.write_str("absoluteIdx.y"),
280            Variable::AbsoluteIdxZ => f.write_str("absoluteIdx.z"),
281            Variable::LocalArray(id, _, _) => {
282                write!(f, "l_arr_{}", id)
283            }
284            Variable::WarpSize => f.write_str("warpSize"),
285            Variable::ThreadIdxWarp => f.write_str("threadIdxGlobal % warpSize"),
286            Variable::WmmaFragment { id: index, frag } => {
287                let name = match frag.ident {
288                    FragmentIdent::A => "a",
289                    FragmentIdent::B => "b",
290                    FragmentIdent::Accumulator => "acc",
291                    FragmentIdent::_Dialect(_) => "",
292                };
293                write!(f, "frag_{name}_{index}")
294            }
295            Variable::GridDimGlobal => f.write_str("gridDimGlobal"),
296            Self::Tmp { id, .. } => write!(f, "_tmp_{id}"),
297        }
298    }
299}
300
/// Result of `Variable::optimized_args`.
///
/// `args` are the operands to use. `optimization_factor` is `Some(before / after)`, the
/// ratio of the largest vectorization before and after packing, when the operands were
/// packed. It is `None` when the original operands were kept.
#[derive(new)]
pub struct OptimizedArgs<const N: usize, D: Dialect> {
    pub args: [Variable<D>; N],
    pub optimization_factor: Option<usize>,
}
306
307impl<D: Dialect> Variable<D> {
308    pub fn is_optimized(&self) -> bool {
309        self.item().is_optimized()
310    }
311
312    pub fn tmp(item: Item<D>) -> Self {
313        let inc = COUNTER_TMP_VAR.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
314
315        Variable::Tmp {
316            id: inc as Id,
317            item,
318        }
319    }
320
321    pub fn optimized_args<const N: usize>(args: [Self; N]) -> OptimizedArgs<N, D> {
322        let args_after = args.map(|a| a.optimized());
323
324        let item_reference_after = args_after[0].item();
325
326        let is_optimized = args_after
327            .iter()
328            .all(|var| var.elem() == item_reference_after.elem && var.is_optimized());
329
330        if is_optimized {
331            let vectorization_before = args
332                .iter()
333                .map(|var| var.item().vectorization)
334                .max()
335                .unwrap();
336            let vectorization_after = args_after
337                .iter()
338                .map(|var| var.item().vectorization)
339                .max()
340                .unwrap();
341
342            OptimizedArgs::new(args_after, Some(vectorization_before / vectorization_after))
343        } else {
344            OptimizedArgs::new(args, None)
345        }
346    }
347
348    pub fn optimized(&self) -> Self {
349        match self {
350            Variable::GlobalInputArray(id, item) => {
351                Variable::GlobalInputArray(*id, item.optimized())
352            }
353            Variable::GlobalOutputArray(id, item) => {
354                Variable::GlobalOutputArray(*id, item.optimized())
355            }
356            Variable::LocalMut { id, item } => Variable::LocalMut {
357                id: *id,
358                item: item.optimized(),
359            },
360            Variable::LocalConst { id, item } => Variable::LocalConst {
361                id: *id,
362                item: item.optimized(),
363            },
364            Variable::Slice { id, item } => Variable::Slice {
365                id: *id,
366                item: item.optimized(),
367            },
368            Variable::SharedMemory(id, item, size) => {
369                let before = item.vectorization;
370                let item = item.optimized();
371                let after = item.vectorization;
372                let scaling = (before / after) as u32;
373
374                Variable::SharedMemory(*id, item, size / scaling)
375            }
376            Variable::LocalArray(id, item, size) => {
377                let before = item.vectorization;
378                let item = item.optimized();
379                let after = item.vectorization;
380                let scaling = (before / after) as u32;
381
382                Variable::LocalArray(*id, item.optimized(), size / scaling)
383            }
384            _ => *self,
385        }
386    }
387
388    pub fn is_always_scalar(&self) -> bool {
389        match self {
390            Variable::GlobalScalar(_, _, _) => true,
391            Variable::ConstantScalar(_, _) => true,
392            Variable::IdxGlobal => true,
393            Variable::ThreadIdxGlobal => true,
394            Variable::ThreadIdxX => true,
395            Variable::ThreadIdxY => true,
396            Variable::ThreadIdxZ => true,
397            Variable::GlobalInputArray(_, _) => false,
398            Variable::GlobalOutputArray(_, _) => false,
399            Variable::SharedMemory(_, _, _) => false,
400            Variable::ConstantArray(_, _, _) => false,
401            Variable::LocalMut { .. } => false,
402            Variable::LocalConst { .. } => false,
403            Variable::Named { .. } => false,
404            Variable::Slice { .. } => false,
405            Variable::BlockIdxX => true,
406            Variable::BlockIdxY => true,
407            Variable::BlockIdxZ => true,
408            Variable::AbsoluteIdxX => true,
409            Variable::AbsoluteIdxY => true,
410            Variable::AbsoluteIdxZ => true,
411            Variable::BlockDimX => true,
412            Variable::BlockDimY => true,
413            Variable::BlockDimZ => true,
414            Variable::GridDimX => true,
415            Variable::GridDimY => true,
416            Variable::GridDimZ => true,
417            Variable::LocalArray(_, _, _) => false,
418            Variable::WarpSize => true,
419            Variable::ThreadIdxWarp => true,
420            Variable::WmmaFragment { .. } => false,
421            Variable::BlockIdxGlobal => true,
422            Variable::BlockDimGlobal => true,
423            Variable::GridDimGlobal => true,
424            Variable::Tmp { .. } => false,
425        }
426    }
427
428    pub fn index(&self, index: usize) -> IndexedVariable<D> {
429        IndexedVariable {
430            var: *self,
431            index,
432            optimized: self.is_optimized(),
433        }
434    }
435
436    pub fn const_qualifier(&self) -> &str {
437        if self.is_const() {
438            " const"
439        } else {
440            ""
441        }
442    }
443}
444
/// Rendering of a value in assignment-target (left-hand side) position.
///
/// Unlike `Display`, this may include a declaration prefix such as `const {item}`.
pub trait FmtLeft: std::fmt::Display {
    /// Text to emit on the left-hand side of an assignment.
    fn fmt_left(&self) -> String;
}
448
449impl<D: Dialect> FmtLeft for Variable<D> {
450    fn fmt_left(&self) -> String {
451        match self {
452            Self::LocalConst { item, .. } => format!("const {item} {self}"),
453            Variable::Tmp { item, .. } => format!("{item} {self}"),
454            var => format!("{var}"),
455        }
456    }
457}
458
459impl<D: Dialect> FmtLeft for IndexedVariable<D> {
460    fn fmt_left(&self) -> String {
461        match self.var {
462            Variable::LocalConst { item, .. } => format!("const {item} {self}"),
463            Variable::Tmp { item, .. } => format!("{item} {self}"),
464            _ => format!("{self}"),
465        }
466    }
467}
468
469impl FmtLeft for &String {
470    fn fmt_left(&self) -> String {
471        self.to_string()
472    }
473}
474
/// A single lane of a variable, produced by `Variable::index`.
#[derive(Debug, Clone)]
pub struct IndexedVariable<D: Dialect> {
    /// The variable being indexed.
    var: Variable<D>,
    /// Whether `var`'s item is a packed `F162`/`BF162` type. When set, rendering wraps
    /// the variable in a `reinterpret_cast` to its item type.
    optimized: bool,
    /// Lane index, rendered as the `.i_{index}` member for vectorized variables.
    index: usize,
}
481
482impl<D: Dialect> Display for IndexedVariable<D> {
483    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
484        let var = &self.var;
485        let ref_ = matches!(var, Variable::LocalConst { .. })
486            .then_some("const&")
487            .unwrap_or("&");
488
489        if self.var.item().vectorization > 1 {
490            if self.optimized {
491                let item = self.var.item();
492                write!(
493                    f,
494                    "(reinterpret_cast<{item} {ref_}>({var})).i_{}",
495                    self.index
496                )
497            } else {
498                write!(f, "{var}.i_{}", self.index)
499            }
500        } else if self.optimized {
501            let item = self.var.item();
502            write!(f, "reinterpret_cast<{item} {ref_}>({var})")
503        } else {
504            write!(f, "{var}")
505        }
506    }
507}
508impl<D: Dialect> Item<D> {
509    pub fn elem(&self) -> &Elem<D> {
510        &self.elem
511    }
512
513    pub fn de_optimized(&self) -> Self {
514        match self.elem {
515            Elem::F162 => Item::new(Elem::F16, self.vectorization * 2),
516            Elem::BF162 => Item::new(Elem::BF16, self.vectorization * 2),
517            _ => *self,
518        }
519    }
520
521    pub fn new(elem: Elem<D>, vectorization: usize) -> Self {
522        Self {
523            elem,
524            vectorization,
525        }
526    }
527    pub fn scalar(elem: Elem<D>) -> Self {
528        Self {
529            elem,
530            vectorization: 1,
531        }
532    }
533
534    pub fn is_optimized(&self) -> bool {
535        matches!(self.elem, Elem::F162 | Elem::BF162)
536    }
537
538    pub fn optimized(&self) -> Item<D> {
539        if self.vectorization == 1 {
540            return *self;
541        }
542
543        if self.vectorization % 2 != 0 {
544            return *self;
545        }
546
547        match self.elem {
548            Elem::F16 => Item {
549                elem: Elem::F162,
550                vectorization: self.vectorization / 2,
551            },
552            Elem::BF16 => Item {
553                elem: Elem::BF162,
554                vectorization: self.vectorization / 2,
555            },
556            _ => *self,
557        }
558    }
559}
560
561impl<D: Dialect> Elem<D> {
562    pub const fn size(&self) -> usize {
563        match self {
564            Elem::F16 => core::mem::size_of::<f16>(),
565            Elem::F162 => 2 * core::mem::size_of::<f16>(),
566            Elem::BF162 => 2 * core::mem::size_of::<bf16>(),
567            Elem::BF16 => core::mem::size_of::<bf16>(),
568            Elem::TF32 => core::mem::size_of::<tf32>(),
569            Elem::F32 => core::mem::size_of::<f32>(),
570            Elem::F64 => core::mem::size_of::<f64>(),
571            Elem::I8 => core::mem::size_of::<i8>(),
572            Elem::I16 => core::mem::size_of::<i16>(),
573            Elem::I32 => core::mem::size_of::<i32>(),
574            Elem::I64 => core::mem::size_of::<i64>(),
575            Elem::U8 => core::mem::size_of::<u8>(),
576            Elem::U16 => core::mem::size_of::<u16>(),
577            Elem::U32 => core::mem::size_of::<u32>(),
578            Elem::U64 => core::mem::size_of::<u64>(),
579            Elem::Bool => core::mem::size_of::<bool>(),
580            Elem::Atomic(AtomicKind::I32) => core::mem::size_of::<i32>(),
581            Elem::Atomic(AtomicKind::I64) => core::mem::size_of::<i64>(),
582            Elem::Atomic(AtomicKind::U32) => core::mem::size_of::<u32>(),
583            Elem::Atomic(AtomicKind::U64) => core::mem::size_of::<u64>(),
584            Elem::Atomic(AtomicKind::F16) => core::mem::size_of::<f16>(),
585            Elem::Atomic(AtomicKind::BF16) => core::mem::size_of::<bf16>(),
586            Elem::Atomic(AtomicKind::F32) => core::mem::size_of::<f32>(),
587            Elem::Atomic(AtomicKind::F64) => core::mem::size_of::<f64>(),
588            Elem::Atomic(AtomicKind::_Dialect(_)) => 0,
589            Elem::_Dialect(_) => 0,
590        }
591    }
592}