1use serde::{Deserialize, Serialize};
4
5use crate::tensor::quantization::{QuantScheme, QuantStore, QuantValue};
6use crate::{bf16, f16};
7
8#[allow(missing_docs)]
9#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize)]
10pub enum DType {
11 F64,
12 F32,
13 Flex32,
14 F16,
15 BF16,
16 I64,
17 I32,
18 I16,
19 I8,
20 U64,
21 U32,
22 U16,
23 U8,
24 Bool,
25 QFloat(QuantScheme),
26}
27
#[cfg(feature = "cubecl")]
impl From<cubecl::ir::ElemType> for DType {
    /// Maps a cubecl element type onto the matching tensor [`DType`].
    ///
    /// # Panics
    ///
    /// - With `"Not a valid DType for tensors."` for `FloatKind::TF32` and for
    ///   any element type that is not a float, signed integer or unsigned integer.
    /// - Via `unimplemented!` for the `E2M1`, `E2M3`, `E3M2`, `E4M3`, `E5M2`
    ///   and `UE8M0` float kinds.
    fn from(value: cubecl::ir::ElemType) -> Self {
        use cubecl::ir::{ElemType, FloatKind, IntKind, UIntKind};

        match value {
            ElemType::UInt(kind) => match kind {
                UIntKind::U64 => DType::U64,
                UIntKind::U32 => DType::U32,
                UIntKind::U16 => DType::U16,
                UIntKind::U8 => DType::U8,
            },
            ElemType::Int(kind) => match kind {
                IntKind::I64 => DType::I64,
                IntKind::I32 => DType::I32,
                IntKind::I16 => DType::I16,
                IntKind::I8 => DType::I8,
            },
            ElemType::Float(kind) => match kind {
                FloatKind::F64 => DType::F64,
                FloatKind::F32 => DType::F32,
                FloatKind::Flex32 => DType::Flex32,
                FloatKind::BF16 => DType::BF16,
                FloatKind::F16 => DType::F16,
                // Listed explicitly (rather than folded into a wildcard) so the
                // inner match stays exhaustive over `FloatKind`.
                FloatKind::TF32 => panic!("Not a valid DType for tensors."),
                FloatKind::E2M1
                | FloatKind::E2M3
                | FloatKind::E3M2
                | FloatKind::E4M3
                | FloatKind::E5M2
                | FloatKind::UE8M0 => {
                    unimplemented!("Not yet supported, will be used for quantization")
                }
            },
            // Every remaining element type has no tensor counterpart here.
            _ => panic!("Not a valid DType for tensors."),
        }
    }
}
64
65impl DType {
66 pub const fn size(&self) -> usize {
68 match self {
69 DType::F64 => core::mem::size_of::<f64>(),
70 DType::F32 => core::mem::size_of::<f32>(),
71 DType::Flex32 => core::mem::size_of::<f32>(),
72 DType::F16 => core::mem::size_of::<f16>(),
73 DType::BF16 => core::mem::size_of::<bf16>(),
74 DType::I64 => core::mem::size_of::<i64>(),
75 DType::I32 => core::mem::size_of::<i32>(),
76 DType::I16 => core::mem::size_of::<i16>(),
77 DType::I8 => core::mem::size_of::<i8>(),
78 DType::U64 => core::mem::size_of::<u64>(),
79 DType::U32 => core::mem::size_of::<u32>(),
80 DType::U16 => core::mem::size_of::<u16>(),
81 DType::U8 => core::mem::size_of::<u8>(),
82 DType::Bool => core::mem::size_of::<bool>(),
83 DType::QFloat(scheme) => match scheme.store {
84 QuantStore::Native => match scheme.value {
85 QuantValue::Q8F | QuantValue::Q8S => core::mem::size_of::<i8>(),
86 QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1 => {
89 core::mem::size_of::<u8>()
90 }
91 QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S => {
92 0
94 }
95 },
96 QuantStore::PackedU32(_) => core::mem::size_of::<u32>(),
97 QuantStore::PackedNative(_) => match scheme.value {
98 QuantValue::E2M1 => core::mem::size_of::<u8>(),
99 _ => 0,
100 },
101 },
102 }
103 }
104 pub fn is_float(&self) -> bool {
106 matches!(
107 self,
108 DType::F64 | DType::F32 | DType::Flex32 | DType::F16 | DType::BF16
109 )
110 }
111 pub fn is_int(&self) -> bool {
113 matches!(self, DType::I64 | DType::I32 | DType::I16 | DType::I8)
114 }
115 pub fn is_uint(&self) -> bool {
117 matches!(self, DType::U64 | DType::U32 | DType::U16 | DType::U8)
118 }
119
120 pub fn is_bool(&self) -> bool {
122 matches!(self, DType::Bool)
123 }
124
125 pub fn name(&self) -> &'static str {
127 match self {
128 DType::F64 => "f64",
129 DType::F32 => "f32",
130 DType::Flex32 => "flex32",
131 DType::F16 => "f16",
132 DType::BF16 => "bf16",
133 DType::I64 => "i64",
134 DType::I32 => "i32",
135 DType::I16 => "i16",
136 DType::I8 => "i8",
137 DType::U64 => "u64",
138 DType::U32 => "u32",
139 DType::U16 => "u16",
140 DType::U8 => "u8",
141 DType::Bool => "bool",
142 DType::QFloat(_) => "qfloat",
143 }
144 }
145}
146
/// The floating point subset of the tensor data types.
///
/// The derived ordering follows declaration order, so the variant order
/// below must stay as is.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum FloatDType {
    /// 64-bit floating point.
    F64,
    /// 32-bit floating point.
    F32,
    /// `Flex32` floating point.
    Flex32,
    /// 16-bit floating point.
    F16,
    /// 16-bit `bf16` floating point.
    BF16,
}
156
157impl From<DType> for FloatDType {
158 fn from(value: DType) -> Self {
159 match value {
160 DType::F64 => FloatDType::F64,
161 DType::F32 => FloatDType::F32,
162 DType::Flex32 => FloatDType::Flex32,
163 DType::F16 => FloatDType::F16,
164 DType::BF16 => FloatDType::BF16,
165 _ => panic!("Expected float data type, got {value:?}"),
166 }
167 }
168}
169
170impl From<FloatDType> for DType {
171 fn from(value: FloatDType) -> Self {
172 match value {
173 FloatDType::F64 => DType::F64,
174 FloatDType::F32 => DType::F32,
175 FloatDType::Flex32 => DType::Flex32,
176 FloatDType::F16 => DType::F16,
177 FloatDType::BF16 => DType::BF16,
178 }
179 }
180}
181
/// The integer subset of the tensor data types, signed and unsigned.
///
/// The derived ordering follows declaration order, so the variant order
/// below must stay as is.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum IntDType {
    /// 64-bit signed integer.
    I64,
    /// 32-bit signed integer.
    I32,
    /// 16-bit signed integer.
    I16,
    /// 8-bit signed integer.
    I8,
    /// 64-bit unsigned integer.
    U64,
    /// 32-bit unsigned integer.
    U32,
    /// 16-bit unsigned integer.
    U16,
    /// 8-bit unsigned integer.
    U8,
}
194
195impl From<DType> for IntDType {
196 fn from(value: DType) -> Self {
197 match value {
198 DType::I64 => IntDType::I64,
199 DType::I32 => IntDType::I32,
200 DType::I16 => IntDType::I16,
201 DType::I8 => IntDType::I8,
202 DType::U64 => IntDType::U64,
203 DType::U32 => IntDType::U32,
204 DType::U16 => IntDType::U16,
205 DType::U8 => IntDType::U8,
206 _ => panic!("Expected int data type, got {value:?}"),
207 }
208 }
209}
210
211impl From<IntDType> for DType {
212 fn from(value: IntDType) -> Self {
213 match value {
214 IntDType::I64 => DType::I64,
215 IntDType::I32 => DType::I32,
216 IntDType::I16 => DType::I16,
217 IntDType::I8 => DType::I8,
218 IntDType::U64 => DType::U64,
219 IntDType::U32 => DType::U32,
220 IntDType::U16 => DType::U16,
221 IntDType::U8 => DType::U8,
222 }
223 }
224}