1use serde::{Deserialize, Serialize};
4
5use crate::tensor::quantization::{QuantScheme, QuantStore, QuantValue};
6use crate::{bf16, f16};
7
/// Element data type of a tensor.
///
/// Plain float / signed-int / unsigned-int element types are unit variants; booleans
/// carry their storage representation ([`BoolStore`]) and quantized floats carry the
/// [`QuantScheme`] describing how they are encoded.
///
/// NOTE: do not reorder variants — serde-derived encodings that store the variant index
/// (rather than its name) would change.
#[allow(missing_docs)]
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub enum DType {
    /// 64-bit floating point (`f64`).
    F64,
    /// 32-bit floating point (`f32`).
    F32,
    /// "Flex" 32-bit float; occupies the size of an `f32` (see [`DType::size`]).
    /// NOTE(review): the exact precision contract is backend-defined — confirm.
    Flex32,
    /// 16-bit IEEE half-precision float (`f16`).
    F16,
    /// 16-bit brain float (`bf16`).
    BF16,
    /// Signed 64-bit integer.
    I64,
    /// Signed 32-bit integer.
    I32,
    /// Signed 16-bit integer.
    I16,
    /// Signed 8-bit integer.
    I8,
    /// Unsigned 64-bit integer.
    U64,
    /// Unsigned 32-bit integer.
    U32,
    /// Unsigned 16-bit integer.
    U16,
    /// Unsigned 8-bit integer.
    U8,
    /// Boolean, stored using the given representation.
    Bool(BoolStore),
    /// Quantized floating point, encoded according to the given scheme.
    QFloat(QuantScheme),
}
27
#[cfg(feature = "cubecl")]
impl From<cubecl::ir::ElemType> for DType {
    /// Maps a CubeCL element type onto the equivalent tensor [`DType`].
    ///
    /// # Panics
    ///
    /// * `FloatKind::TF32`, and any `ElemType` that is not a float, int or uint kind,
    ///   panic with "Not a valid DType for tensors.".
    /// * The minifloat / scale kinds (`E2M1`, `E2M3`, `E3M2`, `E4M3`, `E5M2`, `UE8M0`)
    ///   are not mapped yet and hit `unimplemented!`.
    fn from(value: cubecl::ir::ElemType) -> Self {
        use cubecl::ir::{ElemType, FloatKind, IntKind, UIntKind};

        match value {
            ElemType::Float(kind) => match kind {
                FloatKind::F64 => Self::F64,
                FloatKind::F32 => Self::F32,
                FloatKind::Flex32 => Self::Flex32,
                FloatKind::F16 => Self::F16,
                FloatKind::BF16 => Self::BF16,
                // TF32 has no tensor-level counterpart.
                FloatKind::TF32 => panic!("Not a valid DType for tensors."),
                FloatKind::E2M1
                | FloatKind::E2M3
                | FloatKind::E3M2
                | FloatKind::E4M3
                | FloatKind::E5M2
                | FloatKind::UE8M0 => {
                    unimplemented!("Not yet supported, will be used for quantization")
                }
            },
            ElemType::Int(kind) => match kind {
                IntKind::I64 => Self::I64,
                IntKind::I32 => Self::I32,
                IntKind::I16 => Self::I16,
                IntKind::I8 => Self::I8,
            },
            ElemType::UInt(kind) => match kind {
                UIntKind::U64 => Self::U64,
                UIntKind::U32 => Self::U32,
                UIntKind::U16 => Self::U16,
                UIntKind::U8 => Self::U8,
            },
            _ => panic!("Not a valid DType for tensors."),
        }
    }
}
64
65impl DType {
66 pub const fn size(&self) -> usize {
68 match self {
69 DType::F64 => core::mem::size_of::<f64>(),
70 DType::F32 => core::mem::size_of::<f32>(),
71 DType::Flex32 => core::mem::size_of::<f32>(),
72 DType::F16 => core::mem::size_of::<f16>(),
73 DType::BF16 => core::mem::size_of::<bf16>(),
74 DType::I64 => core::mem::size_of::<i64>(),
75 DType::I32 => core::mem::size_of::<i32>(),
76 DType::I16 => core::mem::size_of::<i16>(),
77 DType::I8 => core::mem::size_of::<i8>(),
78 DType::U64 => core::mem::size_of::<u64>(),
79 DType::U32 => core::mem::size_of::<u32>(),
80 DType::U16 => core::mem::size_of::<u16>(),
81 DType::U8 => core::mem::size_of::<u8>(),
82 DType::Bool(store) => match store {
83 BoolStore::Native => core::mem::size_of::<bool>(),
84 BoolStore::U8 => core::mem::size_of::<u8>(),
85 BoolStore::U32 => core::mem::size_of::<u32>(),
86 },
87 DType::QFloat(scheme) => match scheme.store {
88 QuantStore::Native => match scheme.value {
89 QuantValue::Q8F | QuantValue::Q8S => core::mem::size_of::<i8>(),
90 QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1 => {
93 core::mem::size_of::<u8>()
94 }
95 QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S => {
96 0
98 }
99 },
100 QuantStore::PackedU32(_) => core::mem::size_of::<u32>(),
101 QuantStore::PackedNative(_) => match scheme.value {
102 QuantValue::E2M1 => core::mem::size_of::<u8>(),
103 _ => 0,
104 },
105 },
106 }
107 }
108 pub fn is_float(&self) -> bool {
110 matches!(
111 self,
112 DType::F64 | DType::F32 | DType::Flex32 | DType::F16 | DType::BF16
113 )
114 }
115 pub fn is_int(&self) -> bool {
117 matches!(self, DType::I64 | DType::I32 | DType::I16 | DType::I8)
118 }
119 pub fn is_uint(&self) -> bool {
121 matches!(self, DType::U64 | DType::U32 | DType::U16 | DType::U8)
122 }
123
124 pub fn is_bool(&self) -> bool {
126 matches!(self, DType::Bool(_))
127 }
128
129 pub fn name(&self) -> &'static str {
131 match self {
132 DType::F64 => "f64",
133 DType::F32 => "f32",
134 DType::Flex32 => "flex32",
135 DType::F16 => "f16",
136 DType::BF16 => "bf16",
137 DType::I64 => "i64",
138 DType::I32 => "i32",
139 DType::I16 => "i16",
140 DType::I8 => "i8",
141 DType::U64 => "u64",
142 DType::U32 => "u32",
143 DType::U16 => "u16",
144 DType::U8 => "u8",
145 DType::Bool(store) => match store {
146 BoolStore::Native => "bool",
147 BoolStore::U8 => "bool(u8)",
148 BoolStore::U32 => "bool(u32)",
149 },
150 DType::QFloat(_) => "qfloat",
151 }
152 }
153}
154
/// The (non-quantized) floating point subset of [`DType`].
///
/// NOTE: the derived `PartialOrd`/`Ord` follow declaration order
/// (`F64 < F32 < Flex32 < F16 < BF16`), so reordering variants changes comparisons.
#[allow(missing_docs)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum FloatDType {
    /// 64-bit float; maps to [`DType::F64`].
    F64,
    /// 32-bit float; maps to [`DType::F32`].
    F32,
    /// Flex 32-bit float; maps to [`DType::Flex32`].
    Flex32,
    /// 16-bit half float; maps to [`DType::F16`].
    F16,
    /// 16-bit brain float; maps to [`DType::BF16`].
    BF16,
}
164
165impl From<DType> for FloatDType {
166 fn from(value: DType) -> Self {
167 match value {
168 DType::F64 => FloatDType::F64,
169 DType::F32 => FloatDType::F32,
170 DType::Flex32 => FloatDType::Flex32,
171 DType::F16 => FloatDType::F16,
172 DType::BF16 => FloatDType::BF16,
173 _ => panic!("Expected float data type, got {value:?}"),
174 }
175 }
176}
177
178impl From<FloatDType> for DType {
179 fn from(value: FloatDType) -> Self {
180 match value {
181 FloatDType::F64 => DType::F64,
182 FloatDType::F32 => DType::F32,
183 FloatDType::Flex32 => DType::Flex32,
184 FloatDType::F16 => DType::F16,
185 FloatDType::BF16 => DType::BF16,
186 }
187 }
188}
189
/// The integer subset of [`DType`], covering both signed and unsigned widths.
///
/// NOTE: the derived `PartialOrd`/`Ord` follow declaration order (signed widths from
/// widest to narrowest, then unsigned ones), so reordering variants changes comparisons.
#[allow(missing_docs)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum IntDType {
    /// Signed 64-bit; maps to [`DType::I64`].
    I64,
    /// Signed 32-bit; maps to [`DType::I32`].
    I32,
    /// Signed 16-bit; maps to [`DType::I16`].
    I16,
    /// Signed 8-bit; maps to [`DType::I8`].
    I8,
    /// Unsigned 64-bit; maps to [`DType::U64`].
    U64,
    /// Unsigned 32-bit; maps to [`DType::U32`].
    U32,
    /// Unsigned 16-bit; maps to [`DType::U16`].
    U16,
    /// Unsigned 8-bit; maps to [`DType::U8`].
    U8,
}
202
203impl From<DType> for IntDType {
204 fn from(value: DType) -> Self {
205 match value {
206 DType::I64 => IntDType::I64,
207 DType::I32 => IntDType::I32,
208 DType::I16 => IntDType::I16,
209 DType::I8 => IntDType::I8,
210 DType::U64 => IntDType::U64,
211 DType::U32 => IntDType::U32,
212 DType::U16 => IntDType::U16,
213 DType::U8 => IntDType::U8,
214 _ => panic!("Expected int data type, got {value:?}"),
215 }
216 }
217}
218
219impl From<IntDType> for DType {
220 fn from(value: IntDType) -> Self {
221 match value {
222 IntDType::I64 => DType::I64,
223 IntDType::I32 => DType::I32,
224 IntDType::I16 => DType::I16,
225 IntDType::I8 => DType::I8,
226 IntDType::U64 => DType::U64,
227 IntDType::U32 => DType::U32,
228 IntDType::U16 => DType::U16,
229 IntDType::U8 => DType::U8,
230 }
231 }
232}
233
/// Storage representation used for boolean tensor elements.
///
/// [`DType::size`] reports `size_of::<bool>()`, `size_of::<u8>()` or `size_of::<u32>()`
/// respectively. [`DType::name`] renders these as `"bool"`, `"bool(u8)"` and `"bool(u32)"`.
///
/// NOTE: do not reorder variants — serde-derived encodings that store the variant index
/// would change.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub enum BoolStore {
    /// Stored as Rust `bool`.
    Native,
    /// Stored as one `u8` per element.
    U8,
    /// Stored as one `u32` per element.
    U32,
}
244
/// Alias for [`BoolStore`].
///
/// It is the boolean counterpart of [`FloatDType`] and [`IntDType`] in the
/// `From<DType>` conversions.
pub type BoolDType = BoolStore;
249
250#[allow(deprecated)]
251impl From<DType> for BoolDType {
252 fn from(value: DType) -> Self {
253 match value {
254 DType::Bool(store) => match store {
255 BoolStore::Native => BoolDType::Native,
256 BoolStore::U8 => BoolDType::U8,
257 BoolStore::U32 => BoolDType::U32,
258 },
259 DType::U8 => BoolDType::U8,
261 DType::U32 => BoolDType::U32,
262 _ => panic!("Expected bool data type, got {value:?}"),
263 }
264 }
265}
266
267impl From<BoolDType> for DType {
268 fn from(value: BoolDType) -> Self {
269 match value {
270 BoolDType::Native => DType::Bool(BoolStore::Native),
271 BoolDType::U8 => DType::Bool(BoolStore::U8),
272 BoolDType::U32 => DType::Bool(BoolStore::U32),
273 }
274 }
275}