cubecl_cpp/shared/
element.rs

1use cubecl_common::{e2m1, e2m1x2, e3m2, e5m2};
2use cubecl_core::{
3    ir::{BarrierLevel, ElemType, FloatKind, IntKind, UIntKind},
4    tf32,
5};
6use half::{bf16, f16};
7use std::fmt::Display;
8
9use super::Dialect;
10
11#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash)]
12pub enum Elem<D: Dialect> {
13    TF32,
14    F32,
15    F64,
16    F16,
17    F16x2,
18    BF16,
19    BF16x2,
20    FP4(FP4Kind),
21    FP4x2(FP4Kind),
22    FP6(FP6Kind),
23    FP6x2(FP6Kind),
24    FP8(FP8Kind),
25    FP8x2(FP8Kind),
26    I8,
27    I16,
28    I32,
29    I64,
30    U8,
31    U16,
32    U32,
33    U64,
34    Bool,
35    Barrier(BarrierLevel),
36    Atomic(AtomicKind<D>),
37    _Dialect(std::marker::PhantomData<D>),
38}
39
/// Encodings available for 4-bit floats.
///
/// `Display` writes the lowercase encoding name (`e2m1`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FP4Kind {
    E2M1,
}

impl Display for FP4Kind {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        formatter.write_str(match self {
            Self::E2M1 => "e2m1",
        })
    }
}

/// Encodings available for 6-bit floats.
///
/// `Display` writes the lowercase encoding name (`e2m3` or `e3m2`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FP6Kind {
    E2M3,
    E3M2,
}

impl Display for FP6Kind {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        formatter.write_str(match self {
            Self::E2M3 => "e2m3",
            Self::E3M2 => "e3m2",
        })
    }
}

/// Encodings available for 8-bit floats.
///
/// `Display` writes the lowercase encoding name (`e4m3`, `e5m2` or `e8m0`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FP8Kind {
    E4M3,
    E5M2,
    UE8M0,
}

impl Display for FP8Kind {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        formatter.write_str(match self {
            Self::E4M3 => "e4m3",
            Self::E5M2 => "e5m2",
            // The leading `u` of the variant name is not part of the rendered name.
            Self::UE8M0 => "e8m0",
        })
    }
}
87
88impl<D: Dialect> Display for Elem<D> {
89    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90        D::compile_elem(f, self, false)
91    }
92}
93
94impl<D: Dialect> Elem<D> {
95    pub const fn size(&self) -> usize {
96        match self {
97            Elem::FP4(_) => core::mem::size_of::<e2m1>(),
98            Elem::FP4x2(_) => core::mem::size_of::<e2m1x2>(),
99            Elem::FP6(_) => core::mem::size_of::<e3m2>(),
100            Elem::FP6x2(_) => 2 * core::mem::size_of::<e3m2>(),
101            Elem::FP8(_) => core::mem::size_of::<e5m2>(),
102            Elem::FP8x2(_) => 2 * core::mem::size_of::<e5m2>(),
103            Elem::F16 => core::mem::size_of::<f16>(),
104            Elem::F16x2 => 2 * core::mem::size_of::<f16>(),
105            Elem::BF16x2 => 2 * core::mem::size_of::<bf16>(),
106            Elem::BF16 => core::mem::size_of::<bf16>(),
107            Elem::TF32 => core::mem::size_of::<tf32>(),
108            Elem::F32 => core::mem::size_of::<f32>(),
109            Elem::F64 => core::mem::size_of::<f64>(),
110            Elem::I8 => core::mem::size_of::<i8>(),
111            Elem::I16 => core::mem::size_of::<i16>(),
112            Elem::I32 => core::mem::size_of::<i32>(),
113            Elem::I64 => core::mem::size_of::<i64>(),
114            Elem::U8 => core::mem::size_of::<u8>(),
115            Elem::U16 => core::mem::size_of::<u16>(),
116            Elem::U32 => core::mem::size_of::<u32>(),
117            Elem::U64 => core::mem::size_of::<u64>(),
118            Elem::Bool => core::mem::size_of::<bool>(),
119            Elem::Barrier(_) => core::mem::size_of::<u64>(),
120            Elem::Atomic(AtomicKind::I32) => core::mem::size_of::<i32>(),
121            Elem::Atomic(AtomicKind::I64) => core::mem::size_of::<i64>(),
122            Elem::Atomic(AtomicKind::U32) => core::mem::size_of::<u32>(),
123            Elem::Atomic(AtomicKind::U64) => core::mem::size_of::<u64>(),
124            Elem::Atomic(AtomicKind::F16) => core::mem::size_of::<f16>(),
125            Elem::Atomic(AtomicKind::BF16) => core::mem::size_of::<bf16>(),
126            Elem::Atomic(AtomicKind::F32) => core::mem::size_of::<f32>(),
127            Elem::Atomic(AtomicKind::F64) => core::mem::size_of::<f64>(),
128            Elem::Atomic(AtomicKind::_Dialect(_)) => 0,
129            Elem::_Dialect(_) => 0,
130        }
131    }
132
133    pub const fn size_bits(&self) -> usize {
134        match self {
135            Elem::FP4(_) => 4,
136            other => other.size() * 8,
137        }
138    }
139
140    pub const fn unpacked(&self) -> Self {
141        match self {
142            Elem::FP4x2(ty) => Elem::FP4(*ty),
143            Elem::FP6x2(ty) => Elem::FP6(*ty),
144            Elem::FP8x2(ty) => Elem::FP8(*ty),
145            Elem::F16x2 => Elem::F16,
146            Elem::BF16x2 => Elem::BF16,
147            elem => *elem,
148        }
149    }
150
151    /// Get the number of values packed into a single storage element. (i.e. `f16x2 -> 2`)
152    pub const fn packing_factor(&self) -> usize {
153        match self {
154            Elem::FP4x2(_) | Elem::FP6x2(_) | Elem::FP8x2(_) | Elem::F16x2 | Elem::BF16x2 => 2,
155            _ => 1,
156        }
157    }
158
159    pub const fn ident(&self) -> &str {
160        match self {
161            Elem::FP4(_) => "fp4",
162            Elem::FP4x2(_) => "fp4x2",
163            Elem::FP6(_) => "fp6",
164            Elem::FP6x2(_) => "fp6x2",
165            Elem::FP8(_) => "fp8",
166            Elem::FP8x2(_) => "fp8x2",
167            Elem::F16 => "f16",
168            Elem::F16x2 => "f16x2",
169            Elem::BF16x2 => "bf16x2",
170            Elem::BF16 => "bf16",
171            Elem::TF32 => "tf32",
172            Elem::F32 => "f32",
173            Elem::F64 => "f64",
174            Elem::I8 => "i8",
175            Elem::I16 => "i16",
176            Elem::I32 => "i32",
177            Elem::I64 => "i64",
178            Elem::U8 => "u8",
179            Elem::U16 => "u16",
180            Elem::U32 => "u32",
181            Elem::U64 => "u64",
182            Elem::Bool => "bool",
183            Elem::Barrier(BarrierLevel::Cube) => "cuda::barrier<cuda::thread_scope_block>",
184            Elem::Barrier(BarrierLevel::Unit) => "cuda::barrier<cuda::thread_scope_thread>",
185            Elem::Atomic(_) => "atomic",
186            Elem::_Dialect(_) => "",
187        }
188    }
189}
190
191#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash)]
192pub enum AtomicKind<D: Dialect> {
193    I32,
194    I64,
195    U32,
196    U64,
197    F16,
198    BF16,
199    F32,
200    F64,
201    /// Required to construct the inner `Elem` of the atomic value
202    _Dialect(std::marker::PhantomData<D>),
203}
204
205impl<D: Dialect> From<ElemType> for AtomicKind<D> {
206    fn from(value: ElemType) -> Self {
207        match value {
208            ElemType::Float(FloatKind::F16) => AtomicKind::F16,
209            ElemType::Float(FloatKind::BF16) => AtomicKind::BF16,
210            ElemType::Float(FloatKind::F32) => AtomicKind::F32,
211            ElemType::Float(FloatKind::F64) => AtomicKind::F64,
212            ElemType::Int(IntKind::I32) => AtomicKind::I32,
213            ElemType::Int(IntKind::I64) => AtomicKind::I64,
214            ElemType::UInt(UIntKind::U32) => AtomicKind::U32,
215            ElemType::UInt(UIntKind::U64) => AtomicKind::U64,
216            other => unimplemented!("Invalid atomic type: {other}"),
217        }
218    }
219}
220
221impl<D: Dialect> Display for AtomicKind<D> {
222    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
223        D::compile_atomic_kind(f, self)
224    }
225}