Skip to main content

burn_std/tensor/
dtype.rs

1//! Tensor data type.
2
3use serde::{Deserialize, Serialize};
4
5use crate::tensor::quantization::{QuantScheme, QuantStore, QuantValue};
6use crate::{bf16, f16};
7
8#[allow(missing_docs)]
9#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize)]
10pub enum DType {
11    F64,
12    F32,
13    Flex32,
14    F16,
15    BF16,
16    I64,
17    I32,
18    I16,
19    I8,
20    U64,
21    U32,
22    U16,
23    U8,
24    Bool(BoolStore),
25    QFloat(QuantScheme),
26}
27
#[cfg(feature = "cubecl")]
impl From<cubecl::ir::ElemType> for DType {
    /// Maps a cubecl element type onto the matching tensor [`DType`].
    ///
    /// # Panics
    ///
    /// Panics for `TF32`, for the minifloat kinds (`E2M1`, `E2M3`, `E3M2`,
    /// `E4M3`, `E5M2`, `UE8M0`), and for any element type that is not a
    /// float, signed integer or unsigned integer.
    fn from(value: cubecl::ir::ElemType) -> Self {
        use cubecl::ir::{ElemType, FloatKind, IntKind, UIntKind};

        match value {
            ElemType::Float(kind) => match kind {
                FloatKind::F16 => DType::F16,
                FloatKind::BF16 => DType::BF16,
                FloatKind::Flex32 => DType::Flex32,
                FloatKind::F32 => DType::F32,
                FloatKind::F64 => DType::F64,
                FloatKind::TF32 => panic!("Not a valid DType for tensors."),
                // Minifloat formats are reserved for quantization support.
                FloatKind::E2M1
                | FloatKind::E2M3
                | FloatKind::E3M2
                | FloatKind::E4M3
                | FloatKind::E5M2
                | FloatKind::UE8M0 => {
                    unimplemented!("Not yet supported, will be used for quantization")
                }
            },
            ElemType::Int(kind) => match kind {
                IntKind::I8 => DType::I8,
                IntKind::I16 => DType::I16,
                IntKind::I32 => DType::I32,
                IntKind::I64 => DType::I64,
            },
            ElemType::UInt(kind) => match kind {
                UIntKind::U8 => DType::U8,
                UIntKind::U16 => DType::U16,
                UIntKind::U32 => DType::U32,
                UIntKind::U64 => DType::U64,
            },
            // Every remaining element type has no tensor equivalent.
            _ => panic!("Not a valid DType for tensors."),
        }
    }
}
64
65impl DType {
66    /// Returns the size of a type in bytes.
67    pub const fn size(&self) -> usize {
68        match self {
69            DType::F64 => core::mem::size_of::<f64>(),
70            DType::F32 => core::mem::size_of::<f32>(),
71            DType::Flex32 => core::mem::size_of::<f32>(),
72            DType::F16 => core::mem::size_of::<f16>(),
73            DType::BF16 => core::mem::size_of::<bf16>(),
74            DType::I64 => core::mem::size_of::<i64>(),
75            DType::I32 => core::mem::size_of::<i32>(),
76            DType::I16 => core::mem::size_of::<i16>(),
77            DType::I8 => core::mem::size_of::<i8>(),
78            DType::U64 => core::mem::size_of::<u64>(),
79            DType::U32 => core::mem::size_of::<u32>(),
80            DType::U16 => core::mem::size_of::<u16>(),
81            DType::U8 => core::mem::size_of::<u8>(),
82            DType::Bool(store) => match store {
83                BoolStore::Native => core::mem::size_of::<bool>(),
84                BoolStore::U8 => core::mem::size_of::<u8>(),
85                BoolStore::U32 => core::mem::size_of::<u32>(),
86            },
87            DType::QFloat(scheme) => match scheme.store {
88                QuantStore::Native => match scheme.value {
89                    QuantValue::Q8F | QuantValue::Q8S => core::mem::size_of::<i8>(),
90                    // e2m1 native is automatically packed by the kernels, so the actual storage is
91                    // 8 bits wide.
92                    QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1 => {
93                        core::mem::size_of::<u8>()
94                    }
95                    QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S => {
96                        // Sub-byte values have fractional size
97                        0
98                    }
99                },
100                QuantStore::PackedU32(_) => core::mem::size_of::<u32>(),
101                QuantStore::PackedNative(_) => match scheme.value {
102                    QuantValue::E2M1 => core::mem::size_of::<u8>(),
103                    _ => 0,
104                },
105            },
106        }
107    }
108    /// Returns true if the data type is a floating point type.
109    pub fn is_float(&self) -> bool {
110        matches!(
111            self,
112            DType::F64 | DType::F32 | DType::Flex32 | DType::F16 | DType::BF16
113        )
114    }
115    /// Returns true if the data type is a signed integer type.
116    pub fn is_int(&self) -> bool {
117        matches!(self, DType::I64 | DType::I32 | DType::I16 | DType::I8)
118    }
119    /// Returns true if the data type is an unsigned integer type.
120    pub fn is_uint(&self) -> bool {
121        matches!(self, DType::U64 | DType::U32 | DType::U16 | DType::U8)
122    }
123
124    /// Returns true if the data type is a boolean type
125    pub fn is_bool(&self) -> bool {
126        matches!(self, DType::Bool(_))
127    }
128
129    /// Returns float precision info if this is a float dtype, `None` otherwise.
130    ///
131    /// Analogous to `torch.finfo(dtype)` or `numpy.finfo(dtype)`.
132    pub const fn finfo(&self) -> Option<FloatInfo> {
133        match self {
134            DType::F64 => Some(FloatDType::F64.finfo()),
135            DType::F32 => Some(FloatDType::F32.finfo()),
136            DType::Flex32 => Some(FloatDType::Flex32.finfo()),
137            DType::F16 => Some(FloatDType::F16.finfo()),
138            DType::BF16 => Some(FloatDType::BF16.finfo()),
139            _ => None,
140        }
141    }
142
143    /// Returns the data type name.
144    pub fn name(&self) -> &'static str {
145        match self {
146            DType::F64 => "f64",
147            DType::F32 => "f32",
148            DType::Flex32 => "flex32",
149            DType::F16 => "f16",
150            DType::BF16 => "bf16",
151            DType::I64 => "i64",
152            DType::I32 => "i32",
153            DType::I16 => "i16",
154            DType::I8 => "i8",
155            DType::U64 => "u64",
156            DType::U32 => "u32",
157            DType::U16 => "u16",
158            DType::U8 => "u8",
159            DType::Bool(store) => match store {
160                BoolStore::Native => "bool",
161                BoolStore::U8 => "bool(u8)",
162                BoolStore::U32 => "bool(u32)",
163            },
164            DType::QFloat(_) => "qfloat",
165        }
166    }
167}
168
/// Floating-point subset of the tensor data types.
///
/// The derived ordering follows declaration order, so `F64` compares as the
/// smallest value and `BF16` as the largest.
#[allow(missing_docs)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum FloatDType {
    /// 64-bit float.
    F64,
    /// 32-bit float.
    F32,
    /// Stored as `f32` but computed at reduced (f16-like) precision.
    Flex32,
    /// 16-bit float.
    F16,
    /// 16-bit brain float.
    BF16,
}
178
/// Numerical precision properties for a floating-point dtype.
///
/// Equivalent to NumPy's `finfo` / PyTorch's `torch.finfo`. All values are
/// widened to `f64` so they can be inspected without knowing the concrete
/// element type at compile time.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct FloatInfo {
    /// Machine epsilon: the difference between `1.0` and the next larger
    /// representable value.
    pub epsilon: f64,
    /// Largest representable finite value.
    pub max: f64,
    /// Most negative representable finite value.
    pub min: f64,
    /// Smallest positive normal value.
    pub min_positive: f64,
}
195
196impl FloatDType {
197    /// Returns numerical precision properties for this float dtype.
198    ///
199    /// Analogous to `torch.finfo(dtype)` or `numpy.finfo(dtype)`.
200    pub const fn finfo(self) -> FloatInfo {
201        match self {
202            FloatDType::F64 => FloatInfo {
203                epsilon: f64::EPSILON,
204                max: f64::MAX,
205                min: f64::MIN,
206                min_positive: f64::MIN_POSITIVE, // ~2.225e-308
207            },
208            FloatDType::F32 => FloatInfo {
209                epsilon: f32::EPSILON as f64,
210                max: f32::MAX as f64,
211                min: f32::MIN as f64,
212                min_positive: f32::MIN_POSITIVE as f64, // ~1.175e-38
213            },
214            // Flex32 stores as f32 but computes at reduced (f16-like) precision.
215            // Use f16 precision limits so stability code stays safe.
216            FloatDType::Flex32 => FloatInfo {
217                epsilon: f16::EPSILON.to_f64_const(),
218                max: f16::MAX.to_f64_const(),
219                min: f16::MIN.to_f64_const(),
220                min_positive: f16::MIN_POSITIVE.to_f64_const(), // ~6.104e-5
221            },
222            FloatDType::F16 => FloatInfo {
223                epsilon: f16::EPSILON.to_f64_const(),
224                max: f16::MAX.to_f64_const(),
225                min: f16::MIN.to_f64_const(),
226                min_positive: f16::MIN_POSITIVE.to_f64_const(), // ~6.104e-5
227            },
228            FloatDType::BF16 => FloatInfo {
229                epsilon: bf16::EPSILON.to_f64_const(),
230                max: bf16::MAX.to_f64_const(),
231                min: bf16::MIN.to_f64_const(),
232                min_positive: bf16::MIN_POSITIVE.to_f64_const(), // ~1.175e-38
233            },
234        }
235    }
236}
237
238impl From<DType> for FloatDType {
239    fn from(value: DType) -> Self {
240        match value {
241            DType::F64 => FloatDType::F64,
242            DType::F32 => FloatDType::F32,
243            DType::Flex32 => FloatDType::Flex32,
244            DType::F16 => FloatDType::F16,
245            DType::BF16 => FloatDType::BF16,
246            _ => panic!("Expected float data type, got {value:?}"),
247        }
248    }
249}
250
251impl From<FloatDType> for DType {
252    fn from(value: FloatDType) -> Self {
253        match value {
254            FloatDType::F64 => DType::F64,
255            FloatDType::F32 => DType::F32,
256            FloatDType::Flex32 => DType::Flex32,
257            FloatDType::F16 => DType::F16,
258            FloatDType::BF16 => DType::BF16,
259        }
260    }
261}
262
/// Integer subset of the tensor data types, covering both signed and
/// unsigned variants.
///
/// The derived ordering follows declaration order: signed types first
/// (widest to narrowest), then unsigned types (widest to narrowest).
#[allow(missing_docs)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum IntDType {
    /// 64-bit signed integer.
    I64,
    /// 32-bit signed integer.
    I32,
    /// 16-bit signed integer.
    I16,
    /// 8-bit signed integer.
    I8,
    /// 64-bit unsigned integer.
    U64,
    /// 32-bit unsigned integer.
    U32,
    /// 16-bit unsigned integer.
    U16,
    /// 8-bit unsigned integer.
    U8,
}
275
276impl From<DType> for IntDType {
277    fn from(value: DType) -> Self {
278        match value {
279            DType::I64 => IntDType::I64,
280            DType::I32 => IntDType::I32,
281            DType::I16 => IntDType::I16,
282            DType::I8 => IntDType::I8,
283            DType::U64 => IntDType::U64,
284            DType::U32 => IntDType::U32,
285            DType::U16 => IntDType::U16,
286            DType::U8 => IntDType::U8,
287            _ => panic!("Expected int data type, got {value:?}"),
288        }
289    }
290}
291
292impl From<IntDType> for DType {
293    fn from(value: IntDType) -> Self {
294        match value {
295            IntDType::I64 => DType::I64,
296            IntDType::I32 => DType::I32,
297            IntDType::I16 => DType::I16,
298            IntDType::I8 => DType::I8,
299            IntDType::U64 => DType::U64,
300            IntDType::U32 => DType::U32,
301            IntDType::U16 => DType::U16,
302            IntDType::U8 => DType::U8,
303        }
304    }
305}
306
307#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize)]
308/// Data type used to store boolean values.
309pub enum BoolStore {
310    /// Stored as native boolean type (e.g. `bool`).
311    Native,
312    /// Stored as 8-bit unsigned integer.
313    U8,
314    /// Stored as 32-bit unsigned integer.
315    U32,
316}
317
318/// Boolean dtype.
319///
320/// This is currently an alias to [`BoolStore`], since it only varies by the storage representation.
321pub type BoolDType = BoolStore;
322
323#[allow(deprecated)]
324impl From<DType> for BoolDType {
325    fn from(value: DType) -> Self {
326        match value {
327            DType::Bool(store) => match store {
328                BoolStore::Native => BoolDType::Native,
329                BoolStore::U8 => BoolDType::U8,
330                BoolStore::U32 => BoolDType::U32,
331            },
332            // For compat BoolElem associated type
333            DType::U8 => BoolDType::U8,
334            DType::U32 => BoolDType::U32,
335            _ => panic!("Expected bool data type, got {value:?}"),
336        }
337    }
338}
339
340impl From<BoolDType> for DType {
341    fn from(value: BoolDType) -> Self {
342        match value {
343            BoolDType::Native => DType::Bool(BoolStore::Native),
344            BoolDType::U8 => DType::Bool(BoolStore::U8),
345            BoolDType::U32 => DType::Bool(BoolStore::U32),
346        }
347    }
348}
349
#[cfg(test)]
mod tests {
    use super::*;

    /// `f32` limits are reported unchanged, only widened to `f64`.
    #[test]
    fn finfo_f32() {
        let FloatInfo {
            epsilon,
            max,
            min,
            min_positive,
        } = FloatDType::F32.finfo();

        assert_eq!(epsilon, f64::from(f32::EPSILON));
        assert_eq!(max, f64::from(f32::MAX));
        assert_eq!(min, f64::from(f32::MIN));
        assert_eq!(min_positive, f64::from(f32::MIN_POSITIVE));
    }

    /// `f64` limits are reported as the native constants.
    #[test]
    fn finfo_f64() {
        let FloatInfo {
            epsilon,
            max,
            min,
            min_positive,
        } = FloatDType::F64.finfo();

        assert_eq!(epsilon, f64::EPSILON);
        assert_eq!(max, f64::MAX);
        assert_eq!(min, f64::MIN);
        assert_eq!(min_positive, f64::MIN_POSITIVE);
    }

    #[test]
    fn finfo_f16() {
        let half = FloatDType::F16.finfo();
        let single = FloatDType::F32.finfo();

        assert_eq!(half.epsilon, f16::EPSILON.to_f64_const());
        assert!(half.epsilon > 0.0);
        assert!(half.min_positive > 0.0);
        // f16 epsilon is much larger than f32
        assert!(half.epsilon > single.epsilon);
    }

    #[test]
    fn finfo_bf16() {
        let brain = FloatDType::BF16.finfo();
        let single = FloatDType::F32.finfo();

        assert_eq!(brain.epsilon, bf16::EPSILON.to_f64_const());
        assert!(brain.epsilon > 0.0);
        assert!(brain.min_positive > 0.0);
        // bf16 epsilon is larger than f32 (fewer mantissa bits)
        assert!(brain.epsilon > single.epsilon);
    }

    #[test]
    fn finfo_flex32_uses_f16_limits() {
        let flex = FloatDType::Flex32.finfo();
        let half = FloatDType::F16.finfo();

        assert_eq!(flex.epsilon, half.epsilon);
        assert_eq!(flex.min_positive, half.min_positive);
    }

    /// `DType::finfo` must agree with the matching `FloatDType::finfo`.
    #[test]
    fn dtype_finfo_delegates_to_float_dtype() {
        let pairs = [
            (DType::F32, FloatDType::F32),
            (DType::F64, FloatDType::F64),
            (DType::F16, FloatDType::F16),
            (DType::BF16, FloatDType::BF16),
            (DType::Flex32, FloatDType::Flex32),
        ];

        for (dtype, float_dtype) in pairs {
            assert_eq!(dtype.finfo(), Some(float_dtype.finfo()));
        }
    }

    #[test]
    fn dtype_finfo_returns_none_for_non_float() {
        for dtype in [DType::I32, DType::U8, DType::Bool(BoolStore::Native)] {
            assert!(dtype.finfo().is_none());
        }
    }

    /// Sign and ordering sanity checks that must hold for every float dtype.
    #[test]
    fn finfo_invariants() {
        let all_floats = [
            FloatDType::F64,
            FloatDType::F32,
            FloatDType::F16,
            FloatDType::BF16,
            FloatDType::Flex32,
        ];

        for dtype in all_floats {
            let FloatInfo {
                epsilon,
                max,
                min,
                min_positive,
            } = dtype.finfo();

            assert!(epsilon > 0.0, "{dtype:?}: epsilon must be positive");
            assert!(
                min_positive > 0.0,
                "{dtype:?}: min_positive must be positive"
            );
            assert!(max > 0.0, "{dtype:?}: max must be positive");
            assert!(min < 0.0, "{dtype:?}: min must be negative");
            assert!(max > min_positive, "{dtype:?}: max > min_positive");
        }
    }
}