burn_std/tensor/
dtype.rs

1//! Tensor data type.
2
3use serde::{Deserialize, Serialize};
4
5use crate::tensor::quantization::{QuantScheme, QuantStore, QuantValue};
6use crate::{bf16, f16};
7
8#[allow(missing_docs)]
9#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize)]
10pub enum DType {
11    F64,
12    F32,
13    Flex32,
14    F16,
15    BF16,
16    I64,
17    I32,
18    I16,
19    I8,
20    U64,
21    U32,
22    U16,
23    U8,
24    Bool,
25    QFloat(QuantScheme),
26}
27
#[cfg(feature = "cubecl")]
impl From<cubecl::ir::ElemType> for DType {
    /// Maps a cubecl element type onto the corresponding tensor [`DType`].
    ///
    /// # Panics
    ///
    /// - Panics for `FloatKind::TF32` and for any element type that is not a
    ///   float, signed integer, or unsigned integer.
    /// - Reaches `unimplemented!` for the low-precision float kinds
    ///   (`E2M1`, `E2M3`, `E3M2`, `E4M3`, `E5M2`, `UE8M0`).
    fn from(value: cubecl::ir::ElemType) -> Self {
        use cubecl::ir::{ElemType, FloatKind, IntKind, UIntKind};

        match value {
            ElemType::UInt(unsigned) => match unsigned {
                UIntKind::U64 => DType::U64,
                UIntKind::U32 => DType::U32,
                UIntKind::U16 => DType::U16,
                UIntKind::U8 => DType::U8,
            },
            ElemType::Int(signed) => match signed {
                IntKind::I64 => DType::I64,
                IntKind::I32 => DType::I32,
                IntKind::I16 => DType::I16,
                IntKind::I8 => DType::I8,
            },
            ElemType::Float(float) => match float {
                FloatKind::F64 => DType::F64,
                FloatKind::F32 => DType::F32,
                FloatKind::Flex32 => DType::Flex32,
                FloatKind::F16 => DType::F16,
                FloatKind::BF16 => DType::BF16,
                // TF32 is a compute format only; tensors cannot hold it.
                FloatKind::TF32 => panic!("Not a valid DType for tensors."),
                FloatKind::E2M1
                | FloatKind::E2M3
                | FloatKind::E3M2
                | FloatKind::E4M3
                | FloatKind::E5M2
                | FloatKind::UE8M0 => {
                    unimplemented!("Not yet supported, will be used for quantization")
                }
            },
            _ => panic!("Not a valid DType for tensors."),
        }
    }
}
64
65impl DType {
66    /// Returns the size of a type in bytes.
67    pub const fn size(&self) -> usize {
68        match self {
69            DType::F64 => core::mem::size_of::<f64>(),
70            DType::F32 => core::mem::size_of::<f32>(),
71            DType::Flex32 => core::mem::size_of::<f32>(),
72            DType::F16 => core::mem::size_of::<f16>(),
73            DType::BF16 => core::mem::size_of::<bf16>(),
74            DType::I64 => core::mem::size_of::<i64>(),
75            DType::I32 => core::mem::size_of::<i32>(),
76            DType::I16 => core::mem::size_of::<i16>(),
77            DType::I8 => core::mem::size_of::<i8>(),
78            DType::U64 => core::mem::size_of::<u64>(),
79            DType::U32 => core::mem::size_of::<u32>(),
80            DType::U16 => core::mem::size_of::<u16>(),
81            DType::U8 => core::mem::size_of::<u8>(),
82            DType::Bool => core::mem::size_of::<bool>(),
83            DType::QFloat(scheme) => match scheme.store {
84                QuantStore::Native => match scheme.value {
85                    QuantValue::Q8F | QuantValue::Q8S => core::mem::size_of::<i8>(),
86                    // e2m1 native is automatically packed by the kernels, so the actual storage is
87                    // 8 bits wide.
88                    QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1 => {
89                        core::mem::size_of::<u8>()
90                    }
91                    QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S => {
92                        // Sub-byte values have fractional size
93                        0
94                    }
95                },
96                QuantStore::PackedU32(_) => core::mem::size_of::<u32>(),
97                QuantStore::PackedNative(_) => match scheme.value {
98                    QuantValue::E2M1 => core::mem::size_of::<u8>(),
99                    _ => 0,
100                },
101            },
102        }
103    }
104    /// Returns true if the data type is a floating point type.
105    pub fn is_float(&self) -> bool {
106        matches!(
107            self,
108            DType::F64 | DType::F32 | DType::Flex32 | DType::F16 | DType::BF16
109        )
110    }
111    /// Returns true if the data type is a signed integer type.
112    pub fn is_int(&self) -> bool {
113        matches!(self, DType::I64 | DType::I32 | DType::I16 | DType::I8)
114    }
115    /// Returns true if the data type is an unsigned integer type.
116    pub fn is_uint(&self) -> bool {
117        matches!(self, DType::U64 | DType::U32 | DType::U16 | DType::U8)
118    }
119
120    /// Returns true if the data type is a boolean type
121    pub fn is_bool(&self) -> bool {
122        matches!(self, DType::Bool)
123    }
124
125    /// Returns the data type name.
126    pub fn name(&self) -> &'static str {
127        match self {
128            DType::F64 => "f64",
129            DType::F32 => "f32",
130            DType::Flex32 => "flex32",
131            DType::F16 => "f16",
132            DType::BF16 => "bf16",
133            DType::I64 => "i64",
134            DType::I32 => "i32",
135            DType::I16 => "i16",
136            DType::I8 => "i8",
137            DType::U64 => "u64",
138            DType::U32 => "u32",
139            DType::U16 => "u16",
140            DType::U8 => "u8",
141            DType::Bool => "bool",
142            DType::QFloat(_) => "qfloat",
143        }
144    }
145}
146
/// Floating-point subset of [`DType`].
///
/// Variant order is significant: the derived `PartialOrd`/`Ord` compare by
/// declaration order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum FloatDType {
    /// 64-bit floating point.
    F64,
    /// 32-bit floating point.
    F32,
    /// Floating point stored in 32 bits; see [`DType::Flex32`].
    Flex32,
    /// 16-bit half-precision floating point.
    F16,
    /// 16-bit brain floating point.
    BF16,
}
156
157impl From<DType> for FloatDType {
158    fn from(value: DType) -> Self {
159        match value {
160            DType::F64 => FloatDType::F64,
161            DType::F32 => FloatDType::F32,
162            DType::Flex32 => FloatDType::Flex32,
163            DType::F16 => FloatDType::F16,
164            DType::BF16 => FloatDType::BF16,
165            _ => panic!("Expected float data type, got {value:?}"),
166        }
167    }
168}
169
170impl From<FloatDType> for DType {
171    fn from(value: FloatDType) -> Self {
172        match value {
173            FloatDType::F64 => DType::F64,
174            FloatDType::F32 => DType::F32,
175            FloatDType::Flex32 => DType::Flex32,
176            FloatDType::F16 => DType::F16,
177            FloatDType::BF16 => DType::BF16,
178        }
179    }
180}
181
/// Integer subset of [`DType`], covering both signed and unsigned widths.
///
/// Variant order is significant: the derived `PartialOrd`/`Ord` compare by
/// declaration order (all signed types sort before all unsigned types).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum IntDType {
    /// 64-bit signed integer.
    I64,
    /// 32-bit signed integer.
    I32,
    /// 16-bit signed integer.
    I16,
    /// 8-bit signed integer.
    I8,
    /// 64-bit unsigned integer.
    U64,
    /// 32-bit unsigned integer.
    U32,
    /// 16-bit unsigned integer.
    U16,
    /// 8-bit unsigned integer.
    U8,
}
194
195impl From<DType> for IntDType {
196    fn from(value: DType) -> Self {
197        match value {
198            DType::I64 => IntDType::I64,
199            DType::I32 => IntDType::I32,
200            DType::I16 => IntDType::I16,
201            DType::I8 => IntDType::I8,
202            DType::U64 => IntDType::U64,
203            DType::U32 => IntDType::U32,
204            DType::U16 => IntDType::U16,
205            DType::U8 => IntDType::U8,
206            _ => panic!("Expected int data type, got {value:?}"),
207        }
208    }
209}
210
211impl From<IntDType> for DType {
212    fn from(value: IntDType) -> Self {
213        match value {
214            IntDType::I64 => DType::I64,
215            IntDType::I32 => DType::I32,
216            IntDType::I16 => DType::I16,
217            IntDType::I8 => DType::I8,
218            IntDType::U64 => DType::U64,
219            IntDType::U32 => DType::U32,
220            IntDType::U16 => DType::U16,
221            IntDType::U8 => DType::U8,
222        }
223    }
224}