Skip to main content

burn_std/tensor/
dtype.rs

1//! Tensor data type.
2
3use serde::{Deserialize, Serialize};
4
5use crate::tensor::quantization::{QuantScheme, QuantStore, QuantValue};
6use crate::{bf16, f16};
7
8#[allow(missing_docs)]
9#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize)]
10pub enum DType {
11    F64,
12    F32,
13    Flex32,
14    F16,
15    BF16,
16    I64,
17    I32,
18    I16,
19    I8,
20    U64,
21    U32,
22    U16,
23    U8,
24    Bool(BoolStore),
25    QFloat(QuantScheme),
26}
27
#[cfg(feature = "cubecl")]
impl From<cubecl::ir::ElemType> for DType {
    /// Converts a CubeCL element type into the equivalent tensor [`DType`].
    ///
    /// Float, signed-integer and unsigned-integer kinds map one-to-one onto the matching
    /// variants.
    ///
    /// # Panics
    ///
    /// - `FloatKind::TF32`, and any element type that is neither a float nor an
    ///   (unsigned) integer, panic with "Not a valid DType for tensors.".
    /// - The low-precision float kinds (`E2M1`, `E2M3`, `E3M2`, `E4M3`, `E5M2`, `UE8M0`)
    ///   are not supported yet and hit `unimplemented!`.
    fn from(value: cubecl::ir::ElemType) -> Self {
        use cubecl::ir::{ElemType, FloatKind, IntKind, UIntKind};

        match value {
            ElemType::UInt(kind) => match kind {
                UIntKind::U64 => DType::U64,
                UIntKind::U32 => DType::U32,
                UIntKind::U16 => DType::U16,
                UIntKind::U8 => DType::U8,
            },
            ElemType::Int(kind) => match kind {
                IntKind::I64 => DType::I64,
                IntKind::I32 => DType::I32,
                IntKind::I16 => DType::I16,
                IntKind::I8 => DType::I8,
            },
            ElemType::Float(kind) => match kind {
                FloatKind::F64 => DType::F64,
                FloatKind::F32 => DType::F32,
                FloatKind::Flex32 => DType::Flex32,
                FloatKind::F16 => DType::F16,
                FloatKind::BF16 => DType::BF16,
                // TF32 has no tensor storage counterpart in `DType`.
                FloatKind::TF32 => panic!("Not a valid DType for tensors."),
                // Reserved for quantization; not exposed as standalone dtypes yet.
                FloatKind::E2M1
                | FloatKind::E2M3
                | FloatKind::E3M2
                | FloatKind::E4M3
                | FloatKind::E5M2
                | FloatKind::UE8M0 => {
                    unimplemented!("Not yet supported, will be used for quantization")
                }
            },
            _ => panic!("Not a valid DType for tensors."),
        }
    }
}
64
65impl DType {
66    /// Returns the size of a type in bytes.
67    pub const fn size(&self) -> usize {
68        match self {
69            DType::F64 => core::mem::size_of::<f64>(),
70            DType::F32 => core::mem::size_of::<f32>(),
71            DType::Flex32 => core::mem::size_of::<f32>(),
72            DType::F16 => core::mem::size_of::<f16>(),
73            DType::BF16 => core::mem::size_of::<bf16>(),
74            DType::I64 => core::mem::size_of::<i64>(),
75            DType::I32 => core::mem::size_of::<i32>(),
76            DType::I16 => core::mem::size_of::<i16>(),
77            DType::I8 => core::mem::size_of::<i8>(),
78            DType::U64 => core::mem::size_of::<u64>(),
79            DType::U32 => core::mem::size_of::<u32>(),
80            DType::U16 => core::mem::size_of::<u16>(),
81            DType::U8 => core::mem::size_of::<u8>(),
82            DType::Bool(store) => match store {
83                BoolStore::Native => core::mem::size_of::<bool>(),
84                BoolStore::U8 => core::mem::size_of::<u8>(),
85                BoolStore::U32 => core::mem::size_of::<u32>(),
86            },
87            DType::QFloat(scheme) => match scheme.store {
88                QuantStore::Native => match scheme.value {
89                    QuantValue::Q8F | QuantValue::Q8S => core::mem::size_of::<i8>(),
90                    // e2m1 native is automatically packed by the kernels, so the actual storage is
91                    // 8 bits wide.
92                    QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1 => {
93                        core::mem::size_of::<u8>()
94                    }
95                    QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S => {
96                        // Sub-byte values have fractional size
97                        0
98                    }
99                },
100                QuantStore::PackedU32(_) => core::mem::size_of::<u32>(),
101                QuantStore::PackedNative(_) => match scheme.value {
102                    QuantValue::E2M1 => core::mem::size_of::<u8>(),
103                    _ => 0,
104                },
105            },
106        }
107    }
108    /// Returns true if the data type is a floating point type.
109    pub fn is_float(&self) -> bool {
110        matches!(
111            self,
112            DType::F64 | DType::F32 | DType::Flex32 | DType::F16 | DType::BF16
113        )
114    }
115    /// Returns true if the data type is a signed integer type.
116    pub fn is_int(&self) -> bool {
117        matches!(self, DType::I64 | DType::I32 | DType::I16 | DType::I8)
118    }
119    /// Returns true if the data type is an unsigned integer type.
120    pub fn is_uint(&self) -> bool {
121        matches!(self, DType::U64 | DType::U32 | DType::U16 | DType::U8)
122    }
123
124    /// Returns true if the data type is a boolean type
125    pub fn is_bool(&self) -> bool {
126        matches!(self, DType::Bool(_))
127    }
128
129    /// Returns the data type name.
130    pub fn name(&self) -> &'static str {
131        match self {
132            DType::F64 => "f64",
133            DType::F32 => "f32",
134            DType::Flex32 => "flex32",
135            DType::F16 => "f16",
136            DType::BF16 => "bf16",
137            DType::I64 => "i64",
138            DType::I32 => "i32",
139            DType::I16 => "i16",
140            DType::I8 => "i8",
141            DType::U64 => "u64",
142            DType::U32 => "u32",
143            DType::U16 => "u16",
144            DType::U8 => "u8",
145            DType::Bool(store) => match store {
146                BoolStore::Native => "bool",
147                BoolStore::U8 => "bool(u8)",
148                BoolStore::U32 => "bool(u32)",
149            },
150            DType::QFloat(_) => "qfloat",
151        }
152    }
153}
154
/// Floating point subset of [`DType`].
///
/// The derived `PartialOrd`/`Ord` follow declaration order
/// (`F64` < `F32` < `Flex32` < `F16` < `BF16`), so reordering variants changes comparisons.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum FloatDType {
    /// 64-bit floating point (`f64`).
    F64,
    /// 32-bit floating point (`f32`).
    F32,
    /// Relaxed-precision float stored as an `f32`.
    Flex32,
    /// 16-bit IEEE half-precision float.
    F16,
    /// 16-bit brain floating point.
    BF16,
}
164
165impl From<DType> for FloatDType {
166    fn from(value: DType) -> Self {
167        match value {
168            DType::F64 => FloatDType::F64,
169            DType::F32 => FloatDType::F32,
170            DType::Flex32 => FloatDType::Flex32,
171            DType::F16 => FloatDType::F16,
172            DType::BF16 => FloatDType::BF16,
173            _ => panic!("Expected float data type, got {value:?}"),
174        }
175    }
176}
177
178impl From<FloatDType> for DType {
179    fn from(value: FloatDType) -> Self {
180        match value {
181            FloatDType::F64 => DType::F64,
182            FloatDType::F32 => DType::F32,
183            FloatDType::Flex32 => DType::Flex32,
184            FloatDType::F16 => DType::F16,
185            FloatDType::BF16 => DType::BF16,
186        }
187    }
188}
189
/// Integer subset of [`DType`], covering both signed and unsigned types.
///
/// The derived `PartialOrd`/`Ord` follow declaration order (all signed variants, widest
/// first, then all unsigned variants, widest first), so reordering variants changes
/// comparisons.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum IntDType {
    /// 64-bit signed integer.
    I64,
    /// 32-bit signed integer.
    I32,
    /// 16-bit signed integer.
    I16,
    /// 8-bit signed integer.
    I8,
    /// 64-bit unsigned integer.
    U64,
    /// 32-bit unsigned integer.
    U32,
    /// 16-bit unsigned integer.
    U16,
    /// 8-bit unsigned integer.
    U8,
}
202
203impl From<DType> for IntDType {
204    fn from(value: DType) -> Self {
205        match value {
206            DType::I64 => IntDType::I64,
207            DType::I32 => IntDType::I32,
208            DType::I16 => IntDType::I16,
209            DType::I8 => IntDType::I8,
210            DType::U64 => IntDType::U64,
211            DType::U32 => IntDType::U32,
212            DType::U16 => IntDType::U16,
213            DType::U8 => IntDType::U8,
214            _ => panic!("Expected int data type, got {value:?}"),
215        }
216    }
217}
218
219impl From<IntDType> for DType {
220    fn from(value: IntDType) -> Self {
221        match value {
222            IntDType::I64 => DType::I64,
223            IntDType::I32 => DType::I32,
224            IntDType::I16 => DType::I16,
225            IntDType::I8 => DType::I8,
226            IntDType::U64 => DType::U64,
227            IntDType::U32 => DType::U32,
228            IntDType::U16 => DType::U16,
229            IntDType::U8 => DType::U8,
230        }
231    }
232}
233
234#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize)]
235/// Data type used to store boolean values.
236pub enum BoolStore {
237    /// Stored as native boolean type (e.g. `bool`).
238    Native,
239    /// Stored as 8-bit unsigned integer.
240    U8,
241    /// Stored as 32-bit unsigned integer.
242    U32,
243}
244
245/// Boolean dtype.
246///
247/// This is currently an alias to [`BoolStore`], since it only varies by the storage representation.
248pub type BoolDType = BoolStore;
249
250#[allow(deprecated)]
251impl From<DType> for BoolDType {
252    fn from(value: DType) -> Self {
253        match value {
254            DType::Bool(store) => match store {
255                BoolStore::Native => BoolDType::Native,
256                BoolStore::U8 => BoolDType::U8,
257                BoolStore::U32 => BoolDType::U32,
258            },
259            // For compat BoolElem associated type
260            DType::U8 => BoolDType::U8,
261            DType::U32 => BoolDType::U32,
262            _ => panic!("Expected bool data type, got {value:?}"),
263        }
264    }
265}
266
267impl From<BoolDType> for DType {
268    fn from(value: BoolDType) -> Self {
269        match value {
270            BoolDType::Native => DType::Bool(BoolStore::Native),
271            BoolDType::U8 => DType::Bool(BoolStore::U8),
272            BoolDType::U32 => DType::Bool(BoolStore::U32),
273        }
274    }
275}