cubecl_cpp/shared/
item.rs

1use std::fmt::Display;
2
3use super::{Dialect, Elem};
4
5#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash)]
6pub struct Item<D: Dialect> {
7    pub(crate) elem: Elem<D>,
8    pub(crate) vectorization: usize,
9    pub(crate) native: bool,
10}
11
12impl<D: Dialect> Display for Item<D> {
13    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
14        D::compile_item(f, self)
15    }
16}
17
18impl<D: Dialect> Item<D> {
19    pub fn elem(&self) -> &Elem<D> {
20        &self.elem
21    }
22
23    pub const fn size(&self) -> usize {
24        self.elem.size() * self.vectorization
25    }
26
27    pub fn new(elem: Elem<D>, vectorization: usize, native: bool) -> Self {
28        Self {
29            elem,
30            vectorization,
31            native,
32        }
33    }
34    pub fn scalar(elem: Elem<D>, native: bool) -> Self {
35        Self {
36            elem,
37            vectorization: 1,
38            native,
39        }
40    }
41
42    pub fn can_be_optimized(&self) -> bool {
43        D::item_can_be_optimized()
44    }
45
46    pub fn is_optimized(&self) -> bool {
47        matches!(
48            self.elem,
49            Elem::F16x2 | Elem::BF16x2 | Elem::FP4x2(_) | Elem::FP6x2(_) | Elem::FP8x2(_)
50        )
51    }
52
53    pub fn optimized(&self) -> Item<D> {
54        if !self.can_be_optimized() || self.vectorization % 2 != 0 {
55            return *self;
56        }
57
58        match self.elem {
59            Elem::F16 => Item {
60                elem: Elem::F16x2,
61                vectorization: self.vectorization / 2,
62                native: self.native,
63            },
64            Elem::BF16 => Item {
65                elem: Elem::BF16x2,
66                vectorization: self.vectorization / 2,
67                native: self.native,
68            },
69            Elem::FP4(kind) => Item {
70                elem: Elem::FP4x2(kind),
71                vectorization: self.vectorization / 2,
72                native: self.native,
73            },
74            Elem::FP6(kind) => Item {
75                elem: Elem::FP6x2(kind),
76                vectorization: self.vectorization / 2,
77                native: self.native,
78            },
79            Elem::FP8(kind) => Item {
80                elem: Elem::FP8x2(kind),
81                vectorization: self.vectorization / 2,
82                native: self.native,
83            },
84            _ => *self,
85        }
86    }
87
88    pub fn de_optimized(&self) -> Self {
89        match self.elem {
90            Elem::FP4x2(kind) => Item::new(Elem::FP4(kind), self.vectorization * 2, self.native),
91            Elem::FP6x2(kind) => Item::new(Elem::FP6(kind), self.vectorization * 2, self.native),
92            Elem::FP8x2(kind) => Item::new(Elem::FP8(kind), self.vectorization * 2, self.native),
93            Elem::F16x2 => Item::new(Elem::F16, self.vectorization * 2, self.native),
94            Elem::BF16x2 => Item::new(Elem::BF16, self.vectorization * 2, self.native),
95            _ => *self,
96        }
97    }
98}