Skip to main content

cubecl_cpp/shared/
item.rs

1use std::fmt::Display;
2
3use crate::shared::AtomicKind;
4
5use super::{Dialect, Elem};
6
7#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash)]
8pub struct Item<D: Dialect> {
9    pub(crate) elem: Elem<D>,
10    pub(crate) vectorization: usize,
11    pub(crate) native: bool,
12}
13
14impl<D: Dialect> Display for Item<D> {
15    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
16        D::compile_item(f, self)
17    }
18}
19
20impl<D: Dialect> Item<D> {
21    pub fn elem(&self) -> &Elem<D> {
22        &self.elem
23    }
24
25    pub const fn size(&self) -> usize {
26        self.elem.size() * self.vectorization
27    }
28
29    pub fn new(elem: Elem<D>, vectorization: usize, native: bool) -> Self {
30        Self {
31            elem,
32            vectorization,
33            native,
34        }
35    }
36    pub fn scalar(elem: Elem<D>, native: bool) -> Self {
37        Self {
38            elem,
39            vectorization: 1,
40            native,
41        }
42    }
43
44    pub fn can_be_optimized(&self) -> bool {
45        D::item_can_be_optimized()
46    }
47
48    pub fn is_optimized(&self) -> bool {
49        matches!(
50            self.elem,
51            Elem::F16x2
52                | Elem::BF16x2
53                | Elem::Atomic(AtomicKind::F16x2)
54                | Elem::Atomic(AtomicKind::BF16x2)
55                | Elem::FP4x2(_)
56                | Elem::FP6x2(_)
57                | Elem::FP8x2(_)
58        )
59    }
60
61    pub fn is_atomic(&self) -> bool {
62        matches!(self.elem, Elem::Atomic(_))
63    }
64
65    pub fn optimized(&self) -> Item<D> {
66        if !self.can_be_optimized() || !self.vectorization.is_multiple_of(2) {
67            return *self;
68        }
69
70        match self.elem {
71            Elem::F16 => Item {
72                elem: Elem::F16x2,
73                vectorization: self.vectorization / 2,
74                native: self.native,
75            },
76            Elem::Atomic(AtomicKind::F16) => Item {
77                elem: Elem::Atomic(AtomicKind::F16x2),
78                vectorization: self.vectorization / 2,
79                native: self.native,
80            },
81            Elem::BF16 => Item {
82                elem: Elem::BF16x2,
83                vectorization: self.vectorization / 2,
84                native: self.native,
85            },
86            Elem::Atomic(AtomicKind::BF16) => Item {
87                elem: Elem::Atomic(AtomicKind::BF16x2),
88                vectorization: self.vectorization / 2,
89                native: self.native,
90            },
91            Elem::FP4(kind) => Item {
92                elem: Elem::FP4x2(kind),
93                vectorization: self.vectorization / 2,
94                native: self.native,
95            },
96            Elem::FP6(kind) => Item {
97                elem: Elem::FP6x2(kind),
98                vectorization: self.vectorization / 2,
99                native: self.native,
100            },
101            Elem::FP8(kind) => Item {
102                elem: Elem::FP8x2(kind),
103                vectorization: self.vectorization / 2,
104                native: self.native,
105            },
106            _ => *self,
107        }
108    }
109
110    /// Get the number of values packed into a single storage element. (i.e. `f16x2 -> 2`)
111    pub fn packing_factor(&self) -> usize {
112        self.elem.packing_factor()
113    }
114
115    pub fn de_optimized(&self) -> Self {
116        match self.elem {
117            Elem::FP4x2(kind) => Item::new(Elem::FP4(kind), self.vectorization * 2, self.native),
118            Elem::FP6x2(kind) => Item::new(Elem::FP6(kind), self.vectorization * 2, self.native),
119            Elem::FP8x2(kind) => Item::new(Elem::FP8(kind), self.vectorization * 2, self.native),
120            Elem::F16x2 => Item::new(Elem::F16, self.vectorization * 2, self.native),
121            Elem::BF16x2 => Item::new(Elem::BF16, self.vectorization * 2, self.native),
122            _ => *self,
123        }
124    }
125}