burn_std/tensor/
dtype.rs

1//! Tensor data type.
2
3use serde::{Deserialize, Serialize};
4
5use crate::tensor::quantization::{QuantScheme, QuantStore, QuantValue};
6use crate::{bf16, f16};
7
8#[allow(missing_docs)]
9#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize)]
10pub enum DType {
11    F64,
12    F32,
13    Flex32,
14    F16,
15    BF16,
16    I64,
17    I32,
18    I16,
19    I8,
20    U64,
21    U32,
22    U16,
23    U8,
24    Bool,
25    QFloat(QuantScheme),
26}
27
#[cfg(feature = "cubecl")]
impl From<cubecl::ir::ElemType> for DType {
    /// Maps a CubeCL element type onto the equivalent tensor [`DType`].
    ///
    /// # Panics
    ///
    /// - Panics for `FloatKind::TF32` and for any element type that is not a plain
    ///   float, signed int or unsigned int kind.
    /// - Hits `unimplemented!` for the low-precision float kinds (`E2M1`, `E2M3`, `E3M2`,
    ///   `E4M3`, `E5M2`, `UE8M0`), which are reserved for quantization.
    fn from(value: cubecl::ir::ElemType) -> Self {
        use cubecl::ir::{ElemType, FloatKind, IntKind, UIntKind};

        match value {
            ElemType::Float(kind) => match kind {
                FloatKind::F64 => Self::F64,
                FloatKind::F32 => Self::F32,
                FloatKind::Flex32 => Self::Flex32,
                FloatKind::F16 => Self::F16,
                FloatKind::BF16 => Self::BF16,
                // TF32 is a compute-only format; there is no tensor storage type for it.
                FloatKind::TF32 => panic!("Not a valid DType for tensors."),
                FloatKind::E2M1
                | FloatKind::E2M3
                | FloatKind::E3M2
                | FloatKind::E4M3
                | FloatKind::E5M2
                | FloatKind::UE8M0 => {
                    unimplemented!("Not yet supported, will be used for quantization")
                }
            },
            ElemType::Int(kind) => match kind {
                IntKind::I64 => Self::I64,
                IntKind::I32 => Self::I32,
                IntKind::I16 => Self::I16,
                IntKind::I8 => Self::I8,
            },
            ElemType::UInt(kind) => match kind {
                UIntKind::U64 => Self::U64,
                UIntKind::U32 => Self::U32,
                UIntKind::U16 => Self::U16,
                UIntKind::U8 => Self::U8,
            },
            _ => panic!("Not a valid DType for tensors."),
        }
    }
}
64
65impl DType {
66    /// Returns the size of a type in bytes.
67    pub const fn size(&self) -> usize {
68        match self {
69            DType::F64 => core::mem::size_of::<f64>(),
70            DType::F32 => core::mem::size_of::<f32>(),
71            DType::Flex32 => core::mem::size_of::<f32>(),
72            DType::F16 => core::mem::size_of::<f16>(),
73            DType::BF16 => core::mem::size_of::<bf16>(),
74            DType::I64 => core::mem::size_of::<i64>(),
75            DType::I32 => core::mem::size_of::<i32>(),
76            DType::I16 => core::mem::size_of::<i16>(),
77            DType::I8 => core::mem::size_of::<i8>(),
78            DType::U64 => core::mem::size_of::<u64>(),
79            DType::U32 => core::mem::size_of::<u32>(),
80            DType::U16 => core::mem::size_of::<u16>(),
81            DType::U8 => core::mem::size_of::<u8>(),
82            DType::Bool => core::mem::size_of::<bool>(),
83            DType::QFloat(scheme) => match scheme.store {
84                QuantStore::Native => match scheme.value {
85                    QuantValue::Q8F | QuantValue::Q8S => core::mem::size_of::<i8>(),
86                    // e2m1 native is automatically packed by the kernels, so the actual storage is
87                    // 8 bits wide.
88                    QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1 => {
89                        core::mem::size_of::<u8>()
90                    }
91                    QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S => {
92                        // Sub-byte values have fractional size
93                        0
94                    }
95                },
96                QuantStore::U32 => core::mem::size_of::<u32>(),
97            },
98        }
99    }
100    /// Returns true if the data type is a floating point type.
101    pub fn is_float(&self) -> bool {
102        matches!(
103            self,
104            DType::F64 | DType::F32 | DType::Flex32 | DType::F16 | DType::BF16
105        )
106    }
107    /// Returns true if the data type is a signed integer type.
108    pub fn is_int(&self) -> bool {
109        matches!(self, DType::I64 | DType::I32 | DType::I16 | DType::I8)
110    }
111    /// Returns true if the data type is an unsigned integer type.
112    pub fn is_uint(&self) -> bool {
113        matches!(self, DType::U64 | DType::U32 | DType::U16 | DType::U8)
114    }
115
116    /// Returns true if the data type is a boolean type
117    pub fn is_bool(&self) -> bool {
118        matches!(self, DType::Bool)
119    }
120
121    /// Returns the data type name.
122    pub fn name(&self) -> &'static str {
123        match self {
124            DType::F64 => "f64",
125            DType::F32 => "f32",
126            DType::Flex32 => "flex32",
127            DType::F16 => "f16",
128            DType::BF16 => "bf16",
129            DType::I64 => "i64",
130            DType::I32 => "i32",
131            DType::I16 => "i16",
132            DType::I8 => "i8",
133            DType::U64 => "u64",
134            DType::U32 => "u32",
135            DType::U16 => "u16",
136            DType::U8 => "u8",
137            DType::Bool => "bool",
138            DType::QFloat(_) => "qfloat",
139        }
140    }
141}
142
/// Floating-point element data types; the float subset of `DType`.
///
/// The derived ordering follows declaration order (`F64` < `F32` < `Flex32` < `F16` < `BF16`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum FloatDType {
    /// 64-bit floating point (`f64`).
    F64,
    /// 32-bit floating point (`f32`).
    F32,
    /// `flex32` floating point.
    Flex32,
    /// 16-bit half-precision floating point (`f16`).
    F16,
    /// 16-bit brain floating point (`bf16`).
    BF16,
}
152
153impl From<DType> for FloatDType {
154    fn from(value: DType) -> Self {
155        match value {
156            DType::F64 => FloatDType::F64,
157            DType::F32 => FloatDType::F32,
158            DType::Flex32 => FloatDType::Flex32,
159            DType::F16 => FloatDType::F16,
160            DType::BF16 => FloatDType::BF16,
161            _ => panic!("Expected float data type, got {value:?}"),
162        }
163    }
164}
165
166impl From<FloatDType> for DType {
167    fn from(value: FloatDType) -> Self {
168        match value {
169            FloatDType::F64 => DType::F64,
170            FloatDType::F32 => DType::F32,
171            FloatDType::Flex32 => DType::Flex32,
172            FloatDType::F16 => DType::F16,
173            FloatDType::BF16 => DType::BF16,
174        }
175    }
176}
177
/// Integer element data types, covering both signed (`I*`) and unsigned (`U*`) widths.
///
/// The derived ordering follows declaration order: all signed types (widest first)
/// sort before all unsigned types (widest first).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum IntDType {
    /// 64-bit signed integer (`i64`).
    I64,
    /// 32-bit signed integer (`i32`).
    I32,
    /// 16-bit signed integer (`i16`).
    I16,
    /// 8-bit signed integer (`i8`).
    I8,
    /// 64-bit unsigned integer (`u64`).
    U64,
    /// 32-bit unsigned integer (`u32`).
    U32,
    /// 16-bit unsigned integer (`u16`).
    U16,
    /// 8-bit unsigned integer (`u8`).
    U8,
}
190
191impl From<DType> for IntDType {
192    fn from(value: DType) -> Self {
193        match value {
194            DType::I64 => IntDType::I64,
195            DType::I32 => IntDType::I32,
196            DType::I16 => IntDType::I16,
197            DType::I8 => IntDType::I8,
198            DType::U64 => IntDType::U64,
199            DType::U32 => IntDType::U32,
200            DType::U16 => IntDType::U16,
201            DType::U8 => IntDType::U8,
202            _ => panic!("Expected int data type, got {value:?}"),
203        }
204    }
205}
206
207impl From<IntDType> for DType {
208    fn from(value: IntDType) -> Self {
209        match value {
210            IntDType::I64 => DType::I64,
211            IntDType::I32 => DType::I32,
212            IntDType::I16 => DType::I16,
213            IntDType::I8 => DType::I8,
214            IntDType::U64 => DType::U64,
215            IntDType::U32 => DType::U32,
216            IntDType::U16 => DType::U16,
217            IntDType::U8 => DType::U8,
218        }
219    }
220}