cubecl_cpp/shared/
element.rs

1use cubecl_core::tf32;
2use half::{bf16, f16};
3use std::fmt::Display;
4
5use super::Dialect;
6
7#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash)]
8pub enum Elem<D: Dialect> {
9    TF32,
10    F32,
11    F64,
12    F16,
13    F162,
14    BF16,
15    BF162,
16    I8,
17    I16,
18    I32,
19    I64,
20    U8,
21    U16,
22    U32,
23    U64,
24    Bool,
25    Atomic(AtomicKind<D>),
26    _Dialect(std::marker::PhantomData<D>),
27}
28
29impl<D: Dialect> Display for Elem<D> {
30    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31        D::compile_elem(f, self, false)
32    }
33}
34
35impl<D: Dialect> Elem<D> {
36    pub const fn size(&self) -> usize {
37        match self {
38            Elem::F16 => core::mem::size_of::<f16>(),
39            Elem::F162 => 2 * core::mem::size_of::<f16>(),
40            Elem::BF162 => 2 * core::mem::size_of::<bf16>(),
41            Elem::BF16 => core::mem::size_of::<bf16>(),
42            Elem::TF32 => core::mem::size_of::<tf32>(),
43            Elem::F32 => core::mem::size_of::<f32>(),
44            Elem::F64 => core::mem::size_of::<f64>(),
45            Elem::I8 => core::mem::size_of::<i8>(),
46            Elem::I16 => core::mem::size_of::<i16>(),
47            Elem::I32 => core::mem::size_of::<i32>(),
48            Elem::I64 => core::mem::size_of::<i64>(),
49            Elem::U8 => core::mem::size_of::<u8>(),
50            Elem::U16 => core::mem::size_of::<u16>(),
51            Elem::U32 => core::mem::size_of::<u32>(),
52            Elem::U64 => core::mem::size_of::<u64>(),
53            Elem::Bool => core::mem::size_of::<bool>(),
54            Elem::Atomic(AtomicKind::I32) => core::mem::size_of::<i32>(),
55            Elem::Atomic(AtomicKind::I64) => core::mem::size_of::<i64>(),
56            Elem::Atomic(AtomicKind::U32) => core::mem::size_of::<u32>(),
57            Elem::Atomic(AtomicKind::U64) => core::mem::size_of::<u64>(),
58            Elem::Atomic(AtomicKind::F16) => core::mem::size_of::<f16>(),
59            Elem::Atomic(AtomicKind::BF16) => core::mem::size_of::<bf16>(),
60            Elem::Atomic(AtomicKind::F32) => core::mem::size_of::<f32>(),
61            Elem::Atomic(AtomicKind::F64) => core::mem::size_of::<f64>(),
62            Elem::Atomic(AtomicKind::_Dialect(_)) => 0,
63            Elem::_Dialect(_) => 0,
64        }
65    }
66}
67
68#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash)]
69pub enum AtomicKind<D: Dialect> {
70    I32,
71    I64,
72    U32,
73    U64,
74    F16,
75    BF16,
76    F32,
77    F64,
78    /// Required to construct the inner `Elem` of the atomic value
79    _Dialect(std::marker::PhantomData<D>),
80}
81
82impl<D: Dialect> Display for AtomicKind<D> {
83    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
84        D::compile_atomic_kind(f, self)
85    }
86}