cubecl_cpp/shared/
item.rs

1use std::fmt::Display;
2
3use super::{Dialect, Elem};
4
5#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash)]
6pub struct Item<D: Dialect> {
7    pub(crate) elem: Elem<D>,
8    pub(crate) vectorization: usize,
9    pub(crate) native: bool,
10}
11
12impl<D: Dialect> Display for Item<D> {
13    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
14        D::compile_item(f, self)
15    }
16}
17
18impl<D: Dialect> Item<D> {
19    pub fn elem(&self) -> &Elem<D> {
20        &self.elem
21    }
22
23    pub fn new(elem: Elem<D>, vectorization: usize, native: bool) -> Self {
24        Self {
25            elem,
26            vectorization,
27            native,
28        }
29    }
30    pub fn scalar(elem: Elem<D>, native: bool) -> Self {
31        Self {
32            elem,
33            vectorization: 1,
34            native,
35        }
36    }
37
38    pub fn can_be_optimized(&self) -> bool {
39        D::item_can_be_optimized()
40    }
41
42    pub fn is_optimized(&self) -> bool {
43        matches!(self.elem, Elem::F162 | Elem::BF162)
44    }
45
46    pub fn optimized(&self) -> Item<D> {
47        if !self.can_be_optimized() || self.vectorization % 2 != 0 {
48            return *self;
49        }
50
51        match self.elem {
52            Elem::F16 => Item {
53                elem: Elem::F162,
54                vectorization: self.vectorization / 2,
55                native: self.native,
56            },
57            Elem::BF16 => Item {
58                elem: Elem::BF162,
59                vectorization: self.vectorization / 2,
60                native: self.native,
61            },
62            _ => *self,
63        }
64    }
65
66    pub fn de_optimized(&self) -> Self {
67        match self.elem {
68            Elem::F162 => Item::new(Elem::F16, self.vectorization * 2, self.native),
69            Elem::BF162 => Item::new(Elem::BF16, self.vectorization * 2, self.native),
70            _ => *self,
71        }
72    }
73}