1use serde::{Deserialize, Serialize};
4
5use crate::tensor::quantization::{QuantScheme, QuantStore, QuantValue};
6use crate::{bf16, f16};
7
/// Element data type of a tensor.
///
/// Covers floating-point, signed and unsigned integer, boolean and quantized
/// float element types. The byte size of each variant is reported by
/// [`DType::size`] and its lowercase label by [`DType::name`].
#[allow(missing_docs)]
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub enum DType {
    /// Float type sized like `f64`.
    F64,
    /// Float type sized like `f32`.
    F32,
    /// `flex32` float type; [`DType::size`] reports it with the size of an `f32`.
    Flex32,
    /// Float type sized like the crate's `f16`.
    F16,
    /// Float type sized like the crate's `bf16`.
    BF16,
    /// Signed integer sized like `i64`.
    I64,
    /// Signed integer sized like `i32`.
    I32,
    /// Signed integer sized like `i16`.
    I16,
    /// Signed integer sized like `i8`.
    I8,
    /// Unsigned integer sized like `u64`.
    U64,
    /// Unsigned integer sized like `u32`.
    U32,
    /// Unsigned integer sized like `u16`.
    U16,
    /// Unsigned integer sized like `u8`.
    U8,
    /// Boolean sized like `bool`.
    Bool,
    /// Quantized float; the attached [`QuantScheme`] determines how values are stored.
    QFloat(QuantScheme),
}
27
#[cfg(feature = "cubecl")]
impl From<cubecl::ir::ElemType> for DType {
    /// Maps a cubecl element type onto the corresponding tensor [`DType`].
    ///
    /// # Panics
    ///
    /// Panics for `FloatKind::TF32`, for the `E2M1`, `E2M3`, `E3M2`, `E4M3`,
    /// `E5M2` and `UE8M0` float kinds (not implemented yet), and for any
    /// element type that is neither a float, an int nor a uint.
    fn from(value: cubecl::ir::ElemType) -> Self {
        use cubecl::ir::{ElemType, FloatKind, IntKind, UIntKind};

        match value {
            ElemType::Float(kind) => match kind {
                FloatKind::F64 => Self::F64,
                FloatKind::F32 => Self::F32,
                FloatKind::Flex32 => Self::Flex32,
                FloatKind::F16 => Self::F16,
                FloatKind::BF16 => Self::BF16,
                FloatKind::TF32 => panic!("Not a valid DType for tensors."),
                // These kinds have no tensor-level DType so far.
                FloatKind::E2M1
                | FloatKind::E2M3
                | FloatKind::E3M2
                | FloatKind::E4M3
                | FloatKind::E5M2
                | FloatKind::UE8M0 => {
                    unimplemented!("Not yet supported, will be used for quantization")
                }
            },
            ElemType::Int(kind) => match kind {
                IntKind::I64 => Self::I64,
                IntKind::I32 => Self::I32,
                IntKind::I16 => Self::I16,
                IntKind::I8 => Self::I8,
            },
            ElemType::UInt(kind) => match kind {
                UIntKind::U64 => Self::U64,
                UIntKind::U32 => Self::U32,
                UIntKind::U16 => Self::U16,
                UIntKind::U8 => Self::U8,
            },
            // Every remaining element type has no tensor equivalent.
            _ => panic!("Not a valid DType for tensors."),
        }
    }
}
64
65impl DType {
66 pub const fn size(&self) -> usize {
68 match self {
69 DType::F64 => core::mem::size_of::<f64>(),
70 DType::F32 => core::mem::size_of::<f32>(),
71 DType::Flex32 => core::mem::size_of::<f32>(),
72 DType::F16 => core::mem::size_of::<f16>(),
73 DType::BF16 => core::mem::size_of::<bf16>(),
74 DType::I64 => core::mem::size_of::<i64>(),
75 DType::I32 => core::mem::size_of::<i32>(),
76 DType::I16 => core::mem::size_of::<i16>(),
77 DType::I8 => core::mem::size_of::<i8>(),
78 DType::U64 => core::mem::size_of::<u64>(),
79 DType::U32 => core::mem::size_of::<u32>(),
80 DType::U16 => core::mem::size_of::<u16>(),
81 DType::U8 => core::mem::size_of::<u8>(),
82 DType::Bool => core::mem::size_of::<bool>(),
83 DType::QFloat(scheme) => match scheme.store {
84 QuantStore::Native => match scheme.value {
85 QuantValue::Q8F | QuantValue::Q8S => core::mem::size_of::<i8>(),
86 QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1 => {
89 core::mem::size_of::<u8>()
90 }
91 QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S => {
92 0
94 }
95 },
96 QuantStore::U32 => core::mem::size_of::<u32>(),
97 },
98 }
99 }
100 pub fn is_float(&self) -> bool {
102 matches!(
103 self,
104 DType::F64 | DType::F32 | DType::Flex32 | DType::F16 | DType::BF16
105 )
106 }
107 pub fn is_int(&self) -> bool {
109 matches!(self, DType::I64 | DType::I32 | DType::I16 | DType::I8)
110 }
111 pub fn is_uint(&self) -> bool {
113 matches!(self, DType::U64 | DType::U32 | DType::U16 | DType::U8)
114 }
115
116 pub fn is_bool(&self) -> bool {
118 matches!(self, DType::Bool)
119 }
120
121 pub fn name(&self) -> &'static str {
123 match self {
124 DType::F64 => "f64",
125 DType::F32 => "f32",
126 DType::Flex32 => "flex32",
127 DType::F16 => "f16",
128 DType::BF16 => "bf16",
129 DType::I64 => "i64",
130 DType::I32 => "i32",
131 DType::I16 => "i16",
132 DType::I8 => "i8",
133 DType::U64 => "u64",
134 DType::U32 => "u32",
135 DType::U16 => "u16",
136 DType::U8 => "u8",
137 DType::Bool => "bool",
138 DType::QFloat(_) => "qfloat",
139 }
140 }
141}
142
/// Floating-point subset of [`DType`].
///
/// Converts to and from [`DType`] through the `From` implementations in this
/// file; converting a non-float [`DType`] into this type panics.
#[allow(missing_docs)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum FloatDType {
    /// Counterpart of [`DType::F64`].
    F64,
    /// Counterpart of [`DType::F32`].
    F32,
    /// Counterpart of [`DType::Flex32`].
    Flex32,
    /// Counterpart of [`DType::F16`].
    F16,
    /// Counterpart of [`DType::BF16`].
    BF16,
}
152
153impl From<DType> for FloatDType {
154 fn from(value: DType) -> Self {
155 match value {
156 DType::F64 => FloatDType::F64,
157 DType::F32 => FloatDType::F32,
158 DType::Flex32 => FloatDType::Flex32,
159 DType::F16 => FloatDType::F16,
160 DType::BF16 => FloatDType::BF16,
161 _ => panic!("Expected float data type, got {value:?}"),
162 }
163 }
164}
165
166impl From<FloatDType> for DType {
167 fn from(value: FloatDType) -> Self {
168 match value {
169 FloatDType::F64 => DType::F64,
170 FloatDType::F32 => DType::F32,
171 FloatDType::Flex32 => DType::Flex32,
172 FloatDType::F16 => DType::F16,
173 FloatDType::BF16 => DType::BF16,
174 }
175 }
176}
177
/// Integer subset of [`DType`], covering both signed and unsigned variants.
///
/// Converts to and from [`DType`] through the `From` implementations in this
/// file; converting a non-integer [`DType`] into this type panics.
#[allow(missing_docs)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum IntDType {
    /// Counterpart of [`DType::I64`].
    I64,
    /// Counterpart of [`DType::I32`].
    I32,
    /// Counterpart of [`DType::I16`].
    I16,
    /// Counterpart of [`DType::I8`].
    I8,
    /// Counterpart of [`DType::U64`].
    U64,
    /// Counterpart of [`DType::U32`].
    U32,
    /// Counterpart of [`DType::U16`].
    U16,
    /// Counterpart of [`DType::U8`].
    U8,
}
190
191impl From<DType> for IntDType {
192 fn from(value: DType) -> Self {
193 match value {
194 DType::I64 => IntDType::I64,
195 DType::I32 => IntDType::I32,
196 DType::I16 => IntDType::I16,
197 DType::I8 => IntDType::I8,
198 DType::U64 => IntDType::U64,
199 DType::U32 => IntDType::U32,
200 DType::U16 => IntDType::U16,
201 DType::U8 => IntDType::U8,
202 _ => panic!("Expected int data type, got {value:?}"),
203 }
204 }
205}
206
207impl From<IntDType> for DType {
208 fn from(value: IntDType) -> Self {
209 match value {
210 IntDType::I64 => DType::I64,
211 IntDType::I32 => DType::I32,
212 IntDType::I16 => DType::I16,
213 IntDType::I8 => DType::I8,
214 IntDType::U64 => DType::U64,
215 IntDType::U32 => DType::U32,
216 IntDType::U16 => DType::U16,
217 IntDType::U8 => DType::U8,
218 }
219 }
220}