cubecl_cpp/shared/
element.rs1use cubecl_core::tf32;
2use half::{bf16, f16};
3use std::fmt::Display;
4
5use super::Dialect;
6
7#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash)]
8pub enum Elem<D: Dialect> {
9 TF32,
10 F32,
11 F64,
12 F16,
13 F162,
14 BF16,
15 BF162,
16 I8,
17 I16,
18 I32,
19 I64,
20 U8,
21 U16,
22 U32,
23 U64,
24 Bool,
25 Atomic(AtomicKind<D>),
26 _Dialect(std::marker::PhantomData<D>),
27}
28
29impl<D: Dialect> Display for Elem<D> {
30 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31 D::compile_elem(f, self, false)
32 }
33}
34
35impl<D: Dialect> Elem<D> {
36 pub const fn size(&self) -> usize {
37 match self {
38 Elem::F16 => core::mem::size_of::<f16>(),
39 Elem::F162 => 2 * core::mem::size_of::<f16>(),
40 Elem::BF162 => 2 * core::mem::size_of::<bf16>(),
41 Elem::BF16 => core::mem::size_of::<bf16>(),
42 Elem::TF32 => core::mem::size_of::<tf32>(),
43 Elem::F32 => core::mem::size_of::<f32>(),
44 Elem::F64 => core::mem::size_of::<f64>(),
45 Elem::I8 => core::mem::size_of::<i8>(),
46 Elem::I16 => core::mem::size_of::<i16>(),
47 Elem::I32 => core::mem::size_of::<i32>(),
48 Elem::I64 => core::mem::size_of::<i64>(),
49 Elem::U8 => core::mem::size_of::<u8>(),
50 Elem::U16 => core::mem::size_of::<u16>(),
51 Elem::U32 => core::mem::size_of::<u32>(),
52 Elem::U64 => core::mem::size_of::<u64>(),
53 Elem::Bool => core::mem::size_of::<bool>(),
54 Elem::Atomic(AtomicKind::I32) => core::mem::size_of::<i32>(),
55 Elem::Atomic(AtomicKind::I64) => core::mem::size_of::<i64>(),
56 Elem::Atomic(AtomicKind::U32) => core::mem::size_of::<u32>(),
57 Elem::Atomic(AtomicKind::U64) => core::mem::size_of::<u64>(),
58 Elem::Atomic(AtomicKind::F16) => core::mem::size_of::<f16>(),
59 Elem::Atomic(AtomicKind::BF16) => core::mem::size_of::<bf16>(),
60 Elem::Atomic(AtomicKind::F32) => core::mem::size_of::<f32>(),
61 Elem::Atomic(AtomicKind::F64) => core::mem::size_of::<f64>(),
62 Elem::Atomic(AtomicKind::_Dialect(_)) => 0,
63 Elem::_Dialect(_) => 0,
64 }
65 }
66}
67
68#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash)]
69pub enum AtomicKind<D: Dialect> {
70 I32,
71 I64,
72 U32,
73 U64,
74 F16,
75 BF16,
76 F32,
77 F64,
78 _Dialect(std::marker::PhantomData<D>),
80}
81
82impl<D: Dialect> Display for AtomicKind<D> {
83 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
84 D::compile_atomic_kind(f, self)
85 }
86}