cubecl_cpp/shared/
item.rs1use std::fmt::Display;
2
3use super::{Dialect, Elem};
4
5#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash)]
6pub struct Item<D: Dialect> {
7 pub(crate) elem: Elem<D>,
8 pub(crate) vectorization: usize,
9 pub(crate) native: bool,
10}
11
12impl<D: Dialect> Display for Item<D> {
13 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
14 D::compile_item(f, self)
15 }
16}
17
18impl<D: Dialect> Item<D> {
19 pub fn elem(&self) -> &Elem<D> {
20 &self.elem
21 }
22
23 pub const fn size(&self) -> usize {
24 self.elem.size() * self.vectorization
25 }
26
27 pub fn new(elem: Elem<D>, vectorization: usize, native: bool) -> Self {
28 Self {
29 elem,
30 vectorization,
31 native,
32 }
33 }
34 pub fn scalar(elem: Elem<D>, native: bool) -> Self {
35 Self {
36 elem,
37 vectorization: 1,
38 native,
39 }
40 }
41
42 pub fn can_be_optimized(&self) -> bool {
43 D::item_can_be_optimized()
44 }
45
46 pub fn is_optimized(&self) -> bool {
47 matches!(
48 self.elem,
49 Elem::F16x2 | Elem::BF16x2 | Elem::FP4x2(_) | Elem::FP6x2(_) | Elem::FP8x2(_)
50 )
51 }
52
53 pub fn optimized(&self) -> Item<D> {
54 if !self.can_be_optimized() || self.vectorization % 2 != 0 {
55 return *self;
56 }
57
58 match self.elem {
59 Elem::F16 => Item {
60 elem: Elem::F16x2,
61 vectorization: self.vectorization / 2,
62 native: self.native,
63 },
64 Elem::BF16 => Item {
65 elem: Elem::BF16x2,
66 vectorization: self.vectorization / 2,
67 native: self.native,
68 },
69 Elem::FP4(kind) => Item {
70 elem: Elem::FP4x2(kind),
71 vectorization: self.vectorization / 2,
72 native: self.native,
73 },
74 Elem::FP6(kind) => Item {
75 elem: Elem::FP6x2(kind),
76 vectorization: self.vectorization / 2,
77 native: self.native,
78 },
79 Elem::FP8(kind) => Item {
80 elem: Elem::FP8x2(kind),
81 vectorization: self.vectorization / 2,
82 native: self.native,
83 },
84 _ => *self,
85 }
86 }
87
88 pub fn de_optimized(&self) -> Self {
89 match self.elem {
90 Elem::FP4x2(kind) => Item::new(Elem::FP4(kind), self.vectorization * 2, self.native),
91 Elem::FP6x2(kind) => Item::new(Elem::FP6(kind), self.vectorization * 2, self.native),
92 Elem::FP8x2(kind) => Item::new(Elem::FP8(kind), self.vectorization * 2, self.native),
93 Elem::F16x2 => Item::new(Elem::F16, self.vectorization * 2, self.native),
94 Elem::BF16x2 => Item::new(Elem::BF16, self.vectorization * 2, self.native),
95 _ => *self,
96 }
97 }
98}