// cubecl_cpp/shared/element.rs

1use cubecl_common::{e2m1, e2m1x2, e3m2, e5m2};
2use cubecl_core::{
3    ir::{ElemType, FloatKind, IntKind, UIntKind},
4    tf32,
5};
6use half::{bf16, f16};
7use std::fmt::Display;
8
9use super::Dialect;
10
11#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash)]
12pub enum Elem<D: Dialect> {
13    TF32,
14    F32,
15    F64,
16    F16,
17    F16x2,
18    BF16,
19    BF16x2,
20    FP4(FP4Kind),
21    FP4x2(FP4Kind),
22    FP6(FP6Kind),
23    FP6x2(FP6Kind),
24    FP8(FP8Kind),
25    FP8x2(FP8Kind),
26    I8,
27    I16,
28    I32,
29    I64,
30    U8,
31    U16,
32    U32,
33    U64,
34    Bool,
35    Atomic(AtomicKind<D>),
36    _Dialect(std::marker::PhantomData<D>),
37}
38
/// Bit layout of a 4-bit float element.
#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash)]
pub enum FP4Kind {
    /// 2 exponent bits, 1 mantissa bit.
    E2M1,
}
43
/// Bit layout of a 6-bit float element.
#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash)]
pub enum FP6Kind {
    /// 2 exponent bits, 3 mantissa bits.
    E2M3,
    /// 3 exponent bits, 2 mantissa bits.
    E3M2,
}
49
/// Bit layout of an 8-bit float element.
#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash)]
pub enum FP8Kind {
    /// 4 exponent bits, 3 mantissa bits.
    E4M3,
    /// 5 exponent bits, 2 mantissa bits.
    E5M2,
    /// Unsigned, 8 exponent bits, no mantissa bits.
    UE8M0,
}
56
57impl Display for FP4Kind {
58    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59        let name = match self {
60            FP4Kind::E2M1 => "e2m1",
61        };
62        f.write_str(name)
63    }
64}
65
66impl Display for FP6Kind {
67    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68        let name = match self {
69            FP6Kind::E2M3 => "e2m3",
70            FP6Kind::E3M2 => "e3m2",
71        };
72        f.write_str(name)
73    }
74}
75
76impl Display for FP8Kind {
77    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
78        let name = match self {
79            FP8Kind::E4M3 => "e4m3",
80            FP8Kind::E5M2 => "e5m2",
81            FP8Kind::UE8M0 => "e8m0",
82        };
83        f.write_str(name)
84    }
85}
86
87impl<D: Dialect> Display for Elem<D> {
88    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89        D::compile_elem(f, self, false)
90    }
91}
92
93impl<D: Dialect> Elem<D> {
94    pub const fn size(&self) -> usize {
95        match self {
96            Elem::FP4(_) => core::mem::size_of::<e2m1>(),
97            Elem::FP4x2(_) => core::mem::size_of::<e2m1x2>(),
98            Elem::FP6(_) => core::mem::size_of::<e3m2>(),
99            Elem::FP6x2(_) => 2 * core::mem::size_of::<e3m2>(),
100            Elem::FP8(_) => core::mem::size_of::<e5m2>(),
101            Elem::FP8x2(_) => 2 * core::mem::size_of::<e5m2>(),
102            Elem::F16 => core::mem::size_of::<f16>(),
103            Elem::F16x2 => 2 * core::mem::size_of::<f16>(),
104            Elem::BF16x2 => 2 * core::mem::size_of::<bf16>(),
105            Elem::BF16 => core::mem::size_of::<bf16>(),
106            Elem::TF32 => core::mem::size_of::<tf32>(),
107            Elem::F32 => core::mem::size_of::<f32>(),
108            Elem::F64 => core::mem::size_of::<f64>(),
109            Elem::I8 => core::mem::size_of::<i8>(),
110            Elem::I16 => core::mem::size_of::<i16>(),
111            Elem::I32 => core::mem::size_of::<i32>(),
112            Elem::I64 => core::mem::size_of::<i64>(),
113            Elem::U8 => core::mem::size_of::<u8>(),
114            Elem::U16 => core::mem::size_of::<u16>(),
115            Elem::U32 => core::mem::size_of::<u32>(),
116            Elem::U64 => core::mem::size_of::<u64>(),
117            Elem::Bool => core::mem::size_of::<bool>(),
118            Elem::Atomic(AtomicKind::I32) => core::mem::size_of::<i32>(),
119            Elem::Atomic(AtomicKind::I64) => core::mem::size_of::<i64>(),
120            Elem::Atomic(AtomicKind::U32) => core::mem::size_of::<u32>(),
121            Elem::Atomic(AtomicKind::U64) => core::mem::size_of::<u64>(),
122            Elem::Atomic(AtomicKind::F16) => core::mem::size_of::<f16>(),
123            Elem::Atomic(AtomicKind::BF16) => core::mem::size_of::<bf16>(),
124            Elem::Atomic(AtomicKind::F32) => core::mem::size_of::<f32>(),
125            Elem::Atomic(AtomicKind::F64) => core::mem::size_of::<f64>(),
126            Elem::Atomic(AtomicKind::_Dialect(_)) => 0,
127            Elem::_Dialect(_) => 0,
128        }
129    }
130
131    pub const fn size_bits(&self) -> usize {
132        match self {
133            Elem::FP4(_) => 4,
134            other => other.size() * 8,
135        }
136    }
137
138    pub const fn unpacked(&self) -> Self {
139        match self {
140            Elem::FP4x2(ty) => Elem::FP4(*ty),
141            Elem::FP6x2(ty) => Elem::FP6(*ty),
142            Elem::FP8x2(ty) => Elem::FP8(*ty),
143            Elem::F16x2 => Elem::F16,
144            Elem::BF16x2 => Elem::BF16,
145            elem => *elem,
146        }
147    }
148
149    /// Get the number of values packed into a single storage element. (i.e. `f16x2 -> 2`)
150    pub const fn packing_factor(&self) -> usize {
151        match self {
152            Elem::FP4x2(_) | Elem::FP6x2(_) | Elem::FP8x2(_) | Elem::F16x2 | Elem::BF16x2 => 2,
153            _ => 1,
154        }
155    }
156
157    pub const fn ident(&self) -> &str {
158        match self {
159            Elem::FP4(_) => "fp4",
160            Elem::FP4x2(_) => "fp4x2",
161            Elem::FP6(_) => "fp6",
162            Elem::FP6x2(_) => "fp6x2",
163            Elem::FP8(_) => "fp8",
164            Elem::FP8x2(_) => "fp8x2",
165            Elem::F16 => "f16",
166            Elem::F16x2 => "f16x2",
167            Elem::BF16x2 => "bf16x2",
168            Elem::BF16 => "bf16",
169            Elem::TF32 => "tf32",
170            Elem::F32 => "f32",
171            Elem::F64 => "f64",
172            Elem::I8 => "i8",
173            Elem::I16 => "i16",
174            Elem::I32 => "i32",
175            Elem::I64 => "i64",
176            Elem::U8 => "u8",
177            Elem::U16 => "u16",
178            Elem::U32 => "u32",
179            Elem::U64 => "u64",
180            Elem::Bool => "bool",
181            Elem::Atomic(_) => "atomic",
182            Elem::_Dialect(_) => "",
183        }
184    }
185}
186
187#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash)]
188pub enum AtomicKind<D: Dialect> {
189    I32,
190    I64,
191    U32,
192    U64,
193    F16,
194    BF16,
195    F32,
196    F64,
197    /// Required to construct the inner `Elem` of the atomic value
198    _Dialect(std::marker::PhantomData<D>),
199}
200
201impl<D: Dialect> From<ElemType> for AtomicKind<D> {
202    fn from(value: ElemType) -> Self {
203        match value {
204            ElemType::Float(FloatKind::F16) => AtomicKind::F16,
205            ElemType::Float(FloatKind::BF16) => AtomicKind::BF16,
206            ElemType::Float(FloatKind::F32) => AtomicKind::F32,
207            ElemType::Float(FloatKind::F64) => AtomicKind::F64,
208            ElemType::Int(IntKind::I32) => AtomicKind::I32,
209            ElemType::Int(IntKind::I64) => AtomicKind::I64,
210            ElemType::UInt(UIntKind::U32) => AtomicKind::U32,
211            ElemType::UInt(UIntKind::U64) => AtomicKind::U64,
212            other => unimplemented!("Invalid atomic type: {other}"),
213        }
214    }
215}
216
217impl<D: Dialect> Display for AtomicKind<D> {
218    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
219        D::compile_atomic_kind(f, self)
220    }
221}