cubecl_cpp/shared/element.rs

1use cubecl_common::{e2m1x2, e3m2, e5m2};
2use cubecl_core::tf32;
3use half::{bf16, f16};
4use std::fmt::Display;
5
6use super::Dialect;
7
8#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash)]
9pub enum Elem<D: Dialect> {
10    TF32,
11    F32,
12    F64,
13    F16,
14    F16x2,
15    BF16,
16    BF16x2,
17    FP4(FP4Kind),
18    FP4x2(FP4Kind),
19    FP6(FP6Kind),
20    FP6x2(FP6Kind),
21    FP8(FP8Kind),
22    FP8x2(FP8Kind),
23    I8,
24    I16,
25    I32,
26    I64,
27    U8,
28    U16,
29    U32,
30    U64,
31    Bool,
32    Atomic(AtomicKind<D>),
33    _Dialect(std::marker::PhantomData<D>),
34}
35
/// Encodings available for 4-bit floats.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum FP4Kind {
    /// 2 exponent bits, 1 mantissa bit.
    E2M1,
}

/// Encodings available for 6-bit floats.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum FP6Kind {
    /// 2 exponent bits, 3 mantissa bits.
    E2M3,
    /// 3 exponent bits, 2 mantissa bits.
    E3M2,
}

/// Encodings available for 8-bit floats.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum FP8Kind {
    /// 4 exponent bits, 3 mantissa bits.
    E4M3,
    /// 5 exponent bits, 2 mantissa bits.
    E5M2,
    /// Unsigned, 8 exponent bits, no mantissa (displayed as `e8m0`).
    UE8M0,
}

/// Writes the lowercase encoding name (e.g. `e2m1`).
impl Display for FP4Kind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::E2M1 => f.write_str("e2m1"),
        }
    }
}

/// Writes the lowercase encoding name (e.g. `e3m2`).
impl Display for FP6Kind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::E2M3 => f.write_str("e2m3"),
            Self::E3M2 => f.write_str("e3m2"),
        }
    }
}

/// Writes the lowercase encoding name; note that `UE8M0` is written as `e8m0`.
impl Display for FP8Kind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::E4M3 => f.write_str("e4m3"),
            Self::E5M2 => f.write_str("e5m2"),
            Self::UE8M0 => f.write_str("e8m0"),
        }
    }
}
83
84impl<D: Dialect> Display for Elem<D> {
85    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
86        D::compile_elem(f, self, false)
87    }
88}
89
90impl<D: Dialect> Elem<D> {
91    pub const fn size(&self) -> usize {
92        match self {
93            Elem::FP4(_) => panic!("Can't get byte size of sub-byte type"),
94            Elem::FP4x2(_) => core::mem::size_of::<e2m1x2>(),
95            Elem::FP6(_) => core::mem::size_of::<e3m2>(),
96            Elem::FP6x2(_) => 2 * core::mem::size_of::<e3m2>(),
97            Elem::FP8(_) => core::mem::size_of::<e5m2>(),
98            Elem::FP8x2(_) => 2 * core::mem::size_of::<e5m2>(),
99            Elem::F16 => core::mem::size_of::<f16>(),
100            Elem::F16x2 => 2 * core::mem::size_of::<f16>(),
101            Elem::BF16x2 => 2 * core::mem::size_of::<bf16>(),
102            Elem::BF16 => core::mem::size_of::<bf16>(),
103            Elem::TF32 => core::mem::size_of::<tf32>(),
104            Elem::F32 => core::mem::size_of::<f32>(),
105            Elem::F64 => core::mem::size_of::<f64>(),
106            Elem::I8 => core::mem::size_of::<i8>(),
107            Elem::I16 => core::mem::size_of::<i16>(),
108            Elem::I32 => core::mem::size_of::<i32>(),
109            Elem::I64 => core::mem::size_of::<i64>(),
110            Elem::U8 => core::mem::size_of::<u8>(),
111            Elem::U16 => core::mem::size_of::<u16>(),
112            Elem::U32 => core::mem::size_of::<u32>(),
113            Elem::U64 => core::mem::size_of::<u64>(),
114            Elem::Bool => core::mem::size_of::<bool>(),
115            Elem::Atomic(AtomicKind::I32) => core::mem::size_of::<i32>(),
116            Elem::Atomic(AtomicKind::I64) => core::mem::size_of::<i64>(),
117            Elem::Atomic(AtomicKind::U32) => core::mem::size_of::<u32>(),
118            Elem::Atomic(AtomicKind::U64) => core::mem::size_of::<u64>(),
119            Elem::Atomic(AtomicKind::F16) => core::mem::size_of::<f16>(),
120            Elem::Atomic(AtomicKind::BF16) => core::mem::size_of::<bf16>(),
121            Elem::Atomic(AtomicKind::F32) => core::mem::size_of::<f32>(),
122            Elem::Atomic(AtomicKind::F64) => core::mem::size_of::<f64>(),
123            Elem::Atomic(AtomicKind::_Dialect(_)) => 0,
124            Elem::_Dialect(_) => 0,
125        }
126    }
127
128    pub const fn size_bits(&self) -> usize {
129        match self {
130            Elem::FP4(_) => 4,
131            other => other.size() * 8,
132        }
133    }
134}
135
136#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash)]
137pub enum AtomicKind<D: Dialect> {
138    I32,
139    I64,
140    U32,
141    U64,
142    F16,
143    BF16,
144    F32,
145    F64,
146    /// Required to construct the inner `Elem` of the atomic value
147    _Dialect(std::marker::PhantomData<D>),
148}
149
150impl<D: Dialect> Display for AtomicKind<D> {
151    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
152        D::compile_atomic_kind(f, self)
153    }
154}