1use serde::{Deserialize, Serialize};
4
5use crate::tensor::quantization::{QuantScheme, QuantStore, QuantValue};
6use crate::{bf16, f16};
7
8#[allow(missing_docs)]
9#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize)]
10pub enum DType {
11 F64,
12 F32,
13 Flex32,
14 F16,
15 BF16,
16 I64,
17 I32,
18 I16,
19 I8,
20 U64,
21 U32,
22 U16,
23 U8,
24 Bool(BoolStore),
25 QFloat(QuantScheme),
26}
27
#[cfg(feature = "cubecl")]
/// Maps a CubeCL IR element type onto the equivalent tensor [`DType`].
///
/// # Panics
///
/// Panics when the element type has no tensor counterpart:
/// - `FloatKind::TF32`,
/// - the low-precision float kinds `E2M1`, `E2M3`, `E3M2`, `E4M3`, `E5M2` and `UE8M0`
///   (reported as unimplemented),
/// - any `ElemType` that is not a float, signed integer or unsigned integer.
impl From<cubecl::ir::ElemType> for DType {
    fn from(value: cubecl::ir::ElemType) -> Self {
        use cubecl::ir::{ElemType, FloatKind, IntKind, UIntKind};

        match value {
            ElemType::Float(kind) => match kind {
                FloatKind::F64 => Self::F64,
                FloatKind::F32 => Self::F32,
                FloatKind::Flex32 => Self::Flex32,
                FloatKind::F16 => Self::F16,
                FloatKind::BF16 => Self::BF16,
                FloatKind::TF32 => panic!("Not a valid DType for tensors."),
                // Listed explicitly (no wildcard) so that a new `FloatKind` variant is a
                // compile error here rather than a silent panic.
                FloatKind::E2M1
                | FloatKind::E2M3
                | FloatKind::E3M2
                | FloatKind::E4M3
                | FloatKind::E5M2
                | FloatKind::UE8M0 => {
                    unimplemented!("Not yet supported, will be used for quantization")
                }
            },
            ElemType::Int(kind) => match kind {
                IntKind::I64 => Self::I64,
                IntKind::I32 => Self::I32,
                IntKind::I16 => Self::I16,
                IntKind::I8 => Self::I8,
            },
            ElemType::UInt(kind) => match kind {
                UIntKind::U64 => Self::U64,
                UIntKind::U32 => Self::U32,
                UIntKind::U16 => Self::U16,
                UIntKind::U8 => Self::U8,
            },
            // Every remaining element type has no tensor dtype.
            _ => panic!("Not a valid DType for tensors."),
        }
    }
}
64
65impl DType {
66 pub const fn size(&self) -> usize {
68 match self {
69 DType::F64 => core::mem::size_of::<f64>(),
70 DType::F32 => core::mem::size_of::<f32>(),
71 DType::Flex32 => core::mem::size_of::<f32>(),
72 DType::F16 => core::mem::size_of::<f16>(),
73 DType::BF16 => core::mem::size_of::<bf16>(),
74 DType::I64 => core::mem::size_of::<i64>(),
75 DType::I32 => core::mem::size_of::<i32>(),
76 DType::I16 => core::mem::size_of::<i16>(),
77 DType::I8 => core::mem::size_of::<i8>(),
78 DType::U64 => core::mem::size_of::<u64>(),
79 DType::U32 => core::mem::size_of::<u32>(),
80 DType::U16 => core::mem::size_of::<u16>(),
81 DType::U8 => core::mem::size_of::<u8>(),
82 DType::Bool(store) => match store {
83 BoolStore::Native => core::mem::size_of::<bool>(),
84 BoolStore::U8 => core::mem::size_of::<u8>(),
85 BoolStore::U32 => core::mem::size_of::<u32>(),
86 },
87 DType::QFloat(scheme) => match scheme.store {
88 QuantStore::Native => match scheme.value {
89 QuantValue::Q8F | QuantValue::Q8S => core::mem::size_of::<i8>(),
90 QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1 => {
93 core::mem::size_of::<u8>()
94 }
95 QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S => {
96 0
98 }
99 },
100 QuantStore::PackedU32(_) => core::mem::size_of::<u32>(),
101 QuantStore::PackedNative(_) => match scheme.value {
102 QuantValue::E2M1 => core::mem::size_of::<u8>(),
103 _ => 0,
104 },
105 },
106 }
107 }
108 pub fn is_float(&self) -> bool {
110 matches!(
111 self,
112 DType::F64 | DType::F32 | DType::Flex32 | DType::F16 | DType::BF16
113 )
114 }
115 pub fn is_int(&self) -> bool {
117 matches!(self, DType::I64 | DType::I32 | DType::I16 | DType::I8)
118 }
119 pub fn is_uint(&self) -> bool {
121 matches!(self, DType::U64 | DType::U32 | DType::U16 | DType::U8)
122 }
123
124 pub fn is_bool(&self) -> bool {
126 matches!(self, DType::Bool(_))
127 }
128
129 pub const fn finfo(&self) -> Option<FloatInfo> {
133 match self {
134 DType::F64 => Some(FloatDType::F64.finfo()),
135 DType::F32 => Some(FloatDType::F32.finfo()),
136 DType::Flex32 => Some(FloatDType::Flex32.finfo()),
137 DType::F16 => Some(FloatDType::F16.finfo()),
138 DType::BF16 => Some(FloatDType::BF16.finfo()),
139 _ => None,
140 }
141 }
142
143 pub fn name(&self) -> &'static str {
145 match self {
146 DType::F64 => "f64",
147 DType::F32 => "f32",
148 DType::Flex32 => "flex32",
149 DType::F16 => "f16",
150 DType::BF16 => "bf16",
151 DType::I64 => "i64",
152 DType::I32 => "i32",
153 DType::I16 => "i16",
154 DType::I8 => "i8",
155 DType::U64 => "u64",
156 DType::U32 => "u32",
157 DType::U16 => "u16",
158 DType::U8 => "u8",
159 DType::Bool(store) => match store {
160 BoolStore::Native => "bool",
161 BoolStore::U8 => "bool(u8)",
162 BoolStore::U32 => "bool(u32)",
163 },
164 DType::QFloat(_) => "qfloat",
165 }
166 }
167}
168
/// The floating-point subset of the tensor data types.
///
/// The derived ordering follows declaration order, from `F64` down to `BF16`.
#[allow(missing_docs)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum FloatDType {
    /// 64-bit floating point.
    F64,
    /// 32-bit floating point.
    F32,
    /// Flexible 32-bit float.
    Flex32,
    /// 16-bit floating point (`f16`).
    F16,
    /// 16-bit brain floating point (`bf16`).
    BF16,
}
178
/// Numeric limits of a floating-point type, with every value widened to `f64`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct FloatInfo {
    /// The type's `EPSILON` constant (machine epsilon).
    pub epsilon: f64,
    /// The type's `MAX` constant: its largest finite value.
    pub max: f64,
    /// The type's `MIN` constant: its most negative finite value.
    pub min: f64,
    /// The type's `MIN_POSITIVE` constant: its smallest positive normal value.
    pub min_positive: f64,
}
195
196impl FloatDType {
197 pub const fn finfo(self) -> FloatInfo {
201 match self {
202 FloatDType::F64 => FloatInfo {
203 epsilon: f64::EPSILON,
204 max: f64::MAX,
205 min: f64::MIN,
206 min_positive: f64::MIN_POSITIVE, },
208 FloatDType::F32 => FloatInfo {
209 epsilon: f32::EPSILON as f64,
210 max: f32::MAX as f64,
211 min: f32::MIN as f64,
212 min_positive: f32::MIN_POSITIVE as f64, },
214 FloatDType::Flex32 => FloatInfo {
217 epsilon: f16::EPSILON.to_f64_const(),
218 max: f16::MAX.to_f64_const(),
219 min: f16::MIN.to_f64_const(),
220 min_positive: f16::MIN_POSITIVE.to_f64_const(), },
222 FloatDType::F16 => FloatInfo {
223 epsilon: f16::EPSILON.to_f64_const(),
224 max: f16::MAX.to_f64_const(),
225 min: f16::MIN.to_f64_const(),
226 min_positive: f16::MIN_POSITIVE.to_f64_const(), },
228 FloatDType::BF16 => FloatInfo {
229 epsilon: bf16::EPSILON.to_f64_const(),
230 max: bf16::MAX.to_f64_const(),
231 min: bf16::MIN.to_f64_const(),
232 min_positive: bf16::MIN_POSITIVE.to_f64_const(), },
234 }
235 }
236}
237
238impl From<DType> for FloatDType {
239 fn from(value: DType) -> Self {
240 match value {
241 DType::F64 => FloatDType::F64,
242 DType::F32 => FloatDType::F32,
243 DType::Flex32 => FloatDType::Flex32,
244 DType::F16 => FloatDType::F16,
245 DType::BF16 => FloatDType::BF16,
246 _ => panic!("Expected float data type, got {value:?}"),
247 }
248 }
249}
250
251impl From<FloatDType> for DType {
252 fn from(value: FloatDType) -> Self {
253 match value {
254 FloatDType::F64 => DType::F64,
255 FloatDType::F32 => DType::F32,
256 FloatDType::Flex32 => DType::Flex32,
257 FloatDType::F16 => DType::F16,
258 FloatDType::BF16 => DType::BF16,
259 }
260 }
261}
262
/// The integer subset of the tensor data types, signed and unsigned.
///
/// The derived ordering follows declaration order: the signed types from `I64` down to
/// `I8`, then the unsigned types from `U64` down to `U8`.
#[allow(missing_docs)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum IntDType {
    /// 64-bit signed integer.
    I64,
    /// 32-bit signed integer.
    I32,
    /// 16-bit signed integer.
    I16,
    /// 8-bit signed integer.
    I8,
    /// 64-bit unsigned integer.
    U64,
    /// 32-bit unsigned integer.
    U32,
    /// 16-bit unsigned integer.
    U16,
    /// 8-bit unsigned integer.
    U8,
}
275
276impl From<DType> for IntDType {
277 fn from(value: DType) -> Self {
278 match value {
279 DType::I64 => IntDType::I64,
280 DType::I32 => IntDType::I32,
281 DType::I16 => IntDType::I16,
282 DType::I8 => IntDType::I8,
283 DType::U64 => IntDType::U64,
284 DType::U32 => IntDType::U32,
285 DType::U16 => IntDType::U16,
286 DType::U8 => IntDType::U8,
287 _ => panic!("Expected int data type, got {value:?}"),
288 }
289 }
290}
291
292impl From<IntDType> for DType {
293 fn from(value: IntDType) -> Self {
294 match value {
295 IntDType::I64 => DType::I64,
296 IntDType::I32 => DType::I32,
297 IntDType::I16 => DType::I16,
298 IntDType::I8 => DType::I8,
299 IntDType::U64 => DType::U64,
300 IntDType::U32 => DType::U32,
301 IntDType::U16 => DType::U16,
302 IntDType::U8 => DType::U8,
303 }
304 }
305}
306
307#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize)]
308pub enum BoolStore {
310 Native,
312 U8,
314 U32,
316}
317
318pub type BoolDType = BoolStore;
322
323#[allow(deprecated)]
324impl From<DType> for BoolDType {
325 fn from(value: DType) -> Self {
326 match value {
327 DType::Bool(store) => match store {
328 BoolStore::Native => BoolDType::Native,
329 BoolStore::U8 => BoolDType::U8,
330 BoolStore::U32 => BoolDType::U32,
331 },
332 DType::U8 => BoolDType::U8,
334 DType::U32 => BoolDType::U32,
335 _ => panic!("Expected bool data type, got {value:?}"),
336 }
337 }
338}
339
340impl From<BoolDType> for DType {
341 fn from(value: BoolDType) -> Self {
342 match value {
343 BoolDType::Native => DType::Bool(BoolStore::Native),
344 BoolDType::U8 => DType::Bool(BoolStore::U8),
345 BoolDType::U32 => DType::Bool(BoolStore::U32),
346 }
347 }
348}
349
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks shared by the 16-bit float tests: the expected epsilon, positivity, and
    /// an epsilon coarser than `f32`'s.
    fn assert_reduced_precision(info: FloatInfo, expected_epsilon: f64) {
        assert_eq!(info.epsilon, expected_epsilon);
        assert!(info.epsilon > 0.0);
        assert!(info.min_positive > 0.0);
        assert!(info.epsilon > FloatDType::F32.finfo().epsilon);
    }

    #[test]
    fn finfo_f32() {
        let FloatInfo {
            epsilon,
            max,
            min,
            min_positive,
        } = FloatDType::F32.finfo();

        assert_eq!(epsilon, f64::from(f32::EPSILON));
        assert_eq!(max, f64::from(f32::MAX));
        assert_eq!(min, f64::from(f32::MIN));
        assert_eq!(min_positive, f64::from(f32::MIN_POSITIVE));
    }

    #[test]
    fn finfo_f64() {
        let FloatInfo {
            epsilon,
            max,
            min,
            min_positive,
        } = FloatDType::F64.finfo();

        assert_eq!(epsilon, f64::EPSILON);
        assert_eq!(max, f64::MAX);
        assert_eq!(min, f64::MIN);
        assert_eq!(min_positive, f64::MIN_POSITIVE);
    }

    #[test]
    fn finfo_f16() {
        assert_reduced_precision(FloatDType::F16.finfo(), f16::EPSILON.to_f64_const());
    }

    #[test]
    fn finfo_bf16() {
        assert_reduced_precision(FloatDType::BF16.finfo(), bf16::EPSILON.to_f64_const());
    }

    #[test]
    fn finfo_flex32_uses_f16_limits() {
        let flex32 = FloatDType::Flex32.finfo();
        let half = FloatDType::F16.finfo();

        assert_eq!(flex32.epsilon, half.epsilon);
        assert_eq!(flex32.min_positive, half.min_positive);
    }

    #[test]
    fn dtype_finfo_delegates_to_float_dtype() {
        let pairs = [
            (DType::F32, FloatDType::F32),
            (DType::F64, FloatDType::F64),
            (DType::F16, FloatDType::F16),
            (DType::BF16, FloatDType::BF16),
            (DType::Flex32, FloatDType::Flex32),
        ];

        for (dtype, float) in pairs {
            assert_eq!(dtype.finfo(), Some(float.finfo()));
        }
    }

    #[test]
    fn dtype_finfo_returns_none_for_non_float() {
        for dtype in [DType::I32, DType::U8, DType::Bool(BoolStore::Native)] {
            assert!(dtype.finfo().is_none());
        }
    }

    #[test]
    fn finfo_invariants() {
        let all_floats = [
            FloatDType::F64,
            FloatDType::F32,
            FloatDType::F16,
            FloatDType::BF16,
            FloatDType::Flex32,
        ];

        for dtype in all_floats {
            let FloatInfo {
                epsilon,
                max,
                min,
                min_positive,
            } = dtype.finfo();

            assert!(epsilon > 0.0, "{dtype:?}: epsilon must be positive");
            assert!(
                min_positive > 0.0,
                "{dtype:?}: min_positive must be positive"
            );
            assert!(max > 0.0, "{dtype:?}: max must be positive");
            assert!(min < 0.0, "{dtype:?}: min must be negative");
            assert!(max > min_positive, "{dtype:?}: max > min_positive");
        }
    }
}