// cubecl_cpp/shared/element.rs

1use cubecl_common::{e2m1, e2m1x2, e3m2, e5m2};
2use cubecl_core::{
3    ir::{BarrierLevel, ElemType, FloatKind, IntKind, UIntKind},
4    tf32,
5};
6use half::{bf16, f16};
7use std::fmt::Display;
8
9use super::Dialect;
10
11#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash)]
12pub enum Elem<D: Dialect> {
13    TF32,
14    F32,
15    F64,
16    F16,
17    F16x2,
18    BF16,
19    BF16x2,
20    FP4(FP4Kind),
21    FP4x2(FP4Kind),
22    FP6(FP6Kind),
23    FP6x2(FP6Kind),
24    FP8(FP8Kind),
25    FP8x2(FP8Kind),
26    I8,
27    I16,
28    I32,
29    I64,
30    U8,
31    U16,
32    U32,
33    U64,
34    Bool,
35    Barrier(BarrierLevel),
36    Atomic(AtomicKind<D>),
37    _Dialect(std::marker::PhantomData<D>),
38}
39
/// Format selector carried by `Elem::FP4` and `Elem::FP4x2`.
#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash)]
pub enum FP4Kind {
    /// Rendered as `e2m1` by the `Display` impl.
    E2M1,
}
44
/// Format selector carried by `Elem::FP6` and `Elem::FP6x2`.
#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash)]
pub enum FP6Kind {
    /// Rendered as `e2m3` by the `Display` impl.
    E2M3,
    /// Rendered as `e3m2` by the `Display` impl.
    E3M2,
}
50
/// Format selector carried by `Elem::FP8` and `Elem::FP8x2`.
#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash)]
pub enum FP8Kind {
    /// Rendered as `e4m3` by the `Display` impl.
    E4M3,
    /// Rendered as `e5m2` by the `Display` impl.
    E5M2,
    /// Rendered as `e8m0` (without the leading `u`) by the `Display` impl.
    UE8M0,
}
57
58impl Display for FP4Kind {
59    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60        let name = match self {
61            FP4Kind::E2M1 => "e2m1",
62        };
63        f.write_str(name)
64    }
65}
66
67impl Display for FP6Kind {
68    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69        let name = match self {
70            FP6Kind::E2M3 => "e2m3",
71            FP6Kind::E3M2 => "e3m2",
72        };
73        f.write_str(name)
74    }
75}
76
77impl Display for FP8Kind {
78    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79        let name = match self {
80            FP8Kind::E4M3 => "e4m3",
81            FP8Kind::E5M2 => "e5m2",
82            FP8Kind::UE8M0 => "e8m0",
83        };
84        f.write_str(name)
85    }
86}
87
88impl<D: Dialect> Display for Elem<D> {
89    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90        D::compile_elem(f, self, false)
91    }
92}
93
94impl<D: Dialect> Elem<D> {
95    pub const fn size(&self) -> usize {
96        match self {
97            Elem::FP4(_) => core::mem::size_of::<e2m1>(),
98            Elem::FP4x2(_) => core::mem::size_of::<e2m1x2>(),
99            Elem::FP6(_) => core::mem::size_of::<e3m2>(),
100            Elem::FP6x2(_) => 2 * core::mem::size_of::<e3m2>(),
101            Elem::FP8(_) => core::mem::size_of::<e5m2>(),
102            Elem::FP8x2(_) => 2 * core::mem::size_of::<e5m2>(),
103            Elem::F16 => core::mem::size_of::<f16>(),
104            Elem::F16x2 => 2 * core::mem::size_of::<f16>(),
105            Elem::BF16x2 => 2 * core::mem::size_of::<bf16>(),
106            Elem::BF16 => core::mem::size_of::<bf16>(),
107            Elem::TF32 => core::mem::size_of::<tf32>(),
108            Elem::F32 => core::mem::size_of::<f32>(),
109            Elem::F64 => core::mem::size_of::<f64>(),
110            Elem::I8 => core::mem::size_of::<i8>(),
111            Elem::I16 => core::mem::size_of::<i16>(),
112            Elem::I32 => core::mem::size_of::<i32>(),
113            Elem::I64 => core::mem::size_of::<i64>(),
114            Elem::U8 => core::mem::size_of::<u8>(),
115            Elem::U16 => core::mem::size_of::<u16>(),
116            Elem::U32 => core::mem::size_of::<u32>(),
117            Elem::U64 => core::mem::size_of::<u64>(),
118            Elem::Bool => core::mem::size_of::<bool>(),
119            Elem::Barrier(_) => core::mem::size_of::<u64>(),
120            Elem::Atomic(AtomicKind::I32) => core::mem::size_of::<i32>(),
121            Elem::Atomic(AtomicKind::I64) => core::mem::size_of::<i64>(),
122            Elem::Atomic(AtomicKind::U32) => core::mem::size_of::<u32>(),
123            Elem::Atomic(AtomicKind::U64) => core::mem::size_of::<u64>(),
124            Elem::Atomic(AtomicKind::F16) => core::mem::size_of::<f16>(),
125            Elem::Atomic(AtomicKind::F16x2) => core::mem::size_of::<f32>(),
126            Elem::Atomic(AtomicKind::BF16) => core::mem::size_of::<bf16>(),
127            Elem::Atomic(AtomicKind::BF16x2) => core::mem::size_of::<f32>(),
128            Elem::Atomic(AtomicKind::F32) => core::mem::size_of::<f32>(),
129            Elem::Atomic(AtomicKind::F64) => core::mem::size_of::<f64>(),
130            Elem::Atomic(AtomicKind::_Dialect(_)) => 0,
131            Elem::_Dialect(_) => 0,
132        }
133    }
134
135    pub const fn size_bits(&self) -> usize {
136        match self {
137            Elem::FP4(_) => 4,
138            other => other.size() * 8,
139        }
140    }
141
142    pub const fn unpacked(&self) -> Self {
143        match self {
144            Elem::FP4x2(ty) => Elem::FP4(*ty),
145            Elem::FP6x2(ty) => Elem::FP6(*ty),
146            Elem::FP8x2(ty) => Elem::FP8(*ty),
147            Elem::F16x2 => Elem::F16,
148            Elem::BF16x2 => Elem::BF16,
149            elem => *elem,
150        }
151    }
152
153    /// Get the number of values packed into a single storage element. (i.e. `f16x2 -> 2`)
154    pub const fn packing_factor(&self) -> usize {
155        match self {
156            Elem::FP4x2(_) | Elem::FP6x2(_) | Elem::FP8x2(_) | Elem::F16x2 | Elem::BF16x2 => 2,
157            _ => 1,
158        }
159    }
160
161    pub const fn ident(&self) -> &str {
162        match self {
163            Elem::FP4(_) => "fp4",
164            Elem::FP4x2(_) => "fp4x2",
165            Elem::FP6(_) => "fp6",
166            Elem::FP6x2(_) => "fp6x2",
167            Elem::FP8(_) => "fp8",
168            Elem::FP8x2(_) => "fp8x2",
169            Elem::F16 => "f16",
170            Elem::F16x2 => "f16x2",
171            Elem::BF16x2 => "bf16x2",
172            Elem::BF16 => "bf16",
173            Elem::TF32 => "tf32",
174            Elem::F32 => "f32",
175            Elem::F64 => "f64",
176            Elem::I8 => "i8",
177            Elem::I16 => "i16",
178            Elem::I32 => "i32",
179            Elem::I64 => "i64",
180            Elem::U8 => "u8",
181            Elem::U16 => "u16",
182            Elem::U32 => "u32",
183            Elem::U64 => "u64",
184            Elem::Bool => "bool",
185            Elem::Barrier(BarrierLevel::Cube) => "cuda::barrier<cuda::thread_scope_block>",
186            Elem::Barrier(BarrierLevel::Unit) => "cuda::barrier<cuda::thread_scope_thread>",
187            Elem::Atomic(_) => "atomic",
188            Elem::_Dialect(_) => "",
189        }
190    }
191}
192
193#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash)]
194pub enum AtomicKind<D: Dialect> {
195    I32,
196    I64,
197    U32,
198    U64,
199    F16,
200    F16x2,
201    BF16,
202    BF16x2,
203    F32,
204    F64,
205    /// Required to construct the inner `Elem` of the atomic value
206    _Dialect(std::marker::PhantomData<D>),
207}
208
209impl<D: Dialect> From<ElemType> for AtomicKind<D> {
210    fn from(value: ElemType) -> Self {
211        match value {
212            ElemType::Float(FloatKind::F16) => AtomicKind::F16,
213            ElemType::Float(FloatKind::BF16) => AtomicKind::BF16,
214            ElemType::Float(FloatKind::F32) => AtomicKind::F32,
215            ElemType::Float(FloatKind::F64) => AtomicKind::F64,
216            ElemType::Int(IntKind::I32) => AtomicKind::I32,
217            ElemType::Int(IntKind::I64) => AtomicKind::I64,
218            ElemType::UInt(UIntKind::U32) => AtomicKind::U32,
219            ElemType::UInt(UIntKind::U64) => AtomicKind::U64,
220            other => unimplemented!("Invalid atomic type: {other}"),
221        }
222    }
223}
224
225impl<D: Dialect> AtomicKind<D> {
226    pub fn as_elem(self) -> Elem<D> {
227        match self {
228            AtomicKind::I32 => Elem::I32,
229            AtomicKind::I64 => Elem::I64,
230            AtomicKind::U32 => Elem::U32,
231            AtomicKind::U64 => Elem::U64,
232            AtomicKind::F16 => Elem::F16,
233            AtomicKind::F16x2 => Elem::F16x2,
234            AtomicKind::BF16 => Elem::BF16,
235            AtomicKind::BF16x2 => Elem::BF16x2,
236            AtomicKind::F32 => Elem::F32,
237            AtomicKind::F64 => Elem::F64,
238            AtomicKind::_Dialect(_) => unreachable!(),
239        }
240    }
241}
242
243impl<D: Dialect> Display for AtomicKind<D> {
244    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
245        D::compile_atomic_kind(f, self)
246    }
247}