cubecl_cpp/shared/
item.rs1use std::fmt::Display;
2
3use super::{Dialect, Elem};
4
5#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash)]
6pub struct Item<D: Dialect> {
7 pub(crate) elem: Elem<D>,
8 pub(crate) vectorization: usize,
9 pub(crate) native: bool,
10}
11
12impl<D: Dialect> Display for Item<D> {
13 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
14 D::compile_item(f, self)
15 }
16}
17
18impl<D: Dialect> Item<D> {
19 pub fn elem(&self) -> &Elem<D> {
20 &self.elem
21 }
22
23 pub fn new(elem: Elem<D>, vectorization: usize, native: bool) -> Self {
24 Self {
25 elem,
26 vectorization,
27 native,
28 }
29 }
30 pub fn scalar(elem: Elem<D>, native: bool) -> Self {
31 Self {
32 elem,
33 vectorization: 1,
34 native,
35 }
36 }
37
38 pub fn can_be_optimized(&self) -> bool {
39 D::item_can_be_optimized()
40 }
41
42 pub fn is_optimized(&self) -> bool {
43 matches!(self.elem, Elem::F162 | Elem::BF162)
44 }
45
46 pub fn optimized(&self) -> Item<D> {
47 if !self.can_be_optimized() || self.vectorization % 2 != 0 {
48 return *self;
49 }
50
51 match self.elem {
52 Elem::F16 => Item {
53 elem: Elem::F162,
54 vectorization: self.vectorization / 2,
55 native: self.native,
56 },
57 Elem::BF16 => Item {
58 elem: Elem::BF162,
59 vectorization: self.vectorization / 2,
60 native: self.native,
61 },
62 _ => *self,
63 }
64 }
65
66 pub fn de_optimized(&self) -> Self {
67 match self.elem {
68 Elem::F162 => Item::new(Elem::F16, self.vectorization * 2, self.native),
69 Elem::BF162 => Item::new(Elem::BF16, self.vectorization * 2, self.native),
70 _ => *self,
71 }
72 }
73}