cubecl_cpp/shared/
item.rs1use std::fmt::Display;
2
3use crate::shared::AtomicKind;
4
5use super::{Dialect, Elem};
6
7#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash)]
8pub struct Item<D: Dialect> {
9 pub(crate) elem: Elem<D>,
10 pub(crate) vectorization: usize,
11 pub(crate) native: bool,
12}
13
14impl<D: Dialect> Display for Item<D> {
15 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
16 D::compile_item(f, self)
17 }
18}
19
20impl<D: Dialect> Item<D> {
21 pub fn elem(&self) -> &Elem<D> {
22 &self.elem
23 }
24
25 pub const fn size(&self) -> usize {
26 self.elem.size() * self.vectorization
27 }
28
29 pub fn new(elem: Elem<D>, vectorization: usize, native: bool) -> Self {
30 Self {
31 elem,
32 vectorization,
33 native,
34 }
35 }
36 pub fn scalar(elem: Elem<D>, native: bool) -> Self {
37 Self {
38 elem,
39 vectorization: 1,
40 native,
41 }
42 }
43
44 pub fn can_be_optimized(&self) -> bool {
45 D::item_can_be_optimized()
46 }
47
48 pub fn is_optimized(&self) -> bool {
49 matches!(
50 self.elem,
51 Elem::F16x2
52 | Elem::BF16x2
53 | Elem::Atomic(AtomicKind::F16x2)
54 | Elem::Atomic(AtomicKind::BF16x2)
55 | Elem::FP4x2(_)
56 | Elem::FP6x2(_)
57 | Elem::FP8x2(_)
58 )
59 }
60
61 pub fn is_atomic(&self) -> bool {
62 matches!(self.elem, Elem::Atomic(_))
63 }
64
65 pub fn optimized(&self) -> Item<D> {
66 if !self.can_be_optimized() || !self.vectorization.is_multiple_of(2) {
67 return *self;
68 }
69
70 match self.elem {
71 Elem::F16 => Item {
72 elem: Elem::F16x2,
73 vectorization: self.vectorization / 2,
74 native: self.native,
75 },
76 Elem::Atomic(AtomicKind::F16) => Item {
77 elem: Elem::Atomic(AtomicKind::F16x2),
78 vectorization: self.vectorization / 2,
79 native: self.native,
80 },
81 Elem::BF16 => Item {
82 elem: Elem::BF16x2,
83 vectorization: self.vectorization / 2,
84 native: self.native,
85 },
86 Elem::Atomic(AtomicKind::BF16) => Item {
87 elem: Elem::Atomic(AtomicKind::BF16x2),
88 vectorization: self.vectorization / 2,
89 native: self.native,
90 },
91 Elem::FP4(kind) => Item {
92 elem: Elem::FP4x2(kind),
93 vectorization: self.vectorization / 2,
94 native: self.native,
95 },
96 Elem::FP6(kind) => Item {
97 elem: Elem::FP6x2(kind),
98 vectorization: self.vectorization / 2,
99 native: self.native,
100 },
101 Elem::FP8(kind) => Item {
102 elem: Elem::FP8x2(kind),
103 vectorization: self.vectorization / 2,
104 native: self.native,
105 },
106 _ => *self,
107 }
108 }
109
110 pub fn packing_factor(&self) -> usize {
112 self.elem.packing_factor()
113 }
114
115 pub fn de_optimized(&self) -> Self {
116 match self.elem {
117 Elem::FP4x2(kind) => Item::new(Elem::FP4(kind), self.vectorization * 2, self.native),
118 Elem::FP6x2(kind) => Item::new(Elem::FP6(kind), self.vectorization * 2, self.native),
119 Elem::FP8x2(kind) => Item::new(Elem::FP8(kind), self.vectorization * 2, self.native),
120 Elem::F16x2 => Item::new(Elem::F16, self.vectorization * 2, self.native),
121 Elem::BF16x2 => Item::new(Elem::BF16, self.vectorization * 2, self.native),
122 _ => *self,
123 }
124 }
125}