Skip to main content

mlx_native/
dtypes.rs

1//! Data-type enumeration for tensor elements.
2
3use std::fmt;
4
/// Element data type carried by an [`MlxBuffer`](crate::MlxBuffer).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DType {
    /// 32-bit IEEE 754 float.
    F32,
    /// 16-bit IEEE 754 half-precision float.
    F16,
    /// 16-bit brain floating point.
    BF16,
    /// Unsigned 8-bit integer.
    U8,
    /// Unsigned 16-bit integer.
    U16,
    /// Unsigned 32-bit integer.
    U32,
    /// Signed 32-bit integer.
    I32,
}

impl DType {
    /// Size of a single element in bytes.
    ///
    /// This is a `const fn`, so it can be used to size buffers at compile time,
    /// e.g. `const ROW_BYTES: usize = 128 * DType::F16.size_of();`.
    #[inline]
    #[must_use]
    pub const fn size_of(self) -> usize {
        match self {
            DType::F32 | DType::U32 | DType::I32 => 4,
            DType::F16 | DType::BF16 | DType::U16 => 2,
            DType::U8 => 1,
        }
    }

    /// Short lowercase name, e.g. `"f32"`, `"bf16"`.
    ///
    /// This is the same text the [`Display`](fmt::Display) impl writes.
    #[inline]
    #[must_use]
    pub const fn name(self) -> &'static str {
        match self {
            DType::F32 => "f32",
            DType::F16 => "f16",
            DType::BF16 => "bf16",
            DType::U8 => "u8",
            DType::U16 => "u16",
            DType::U32 => "u32",
            DType::I32 => "i32",
        }
    }
}

impl fmt::Display for DType {
    /// Writes [`DType::name`], honoring width, fill, alignment and precision
    /// flags (e.g. `format!("{:>6}", DType::F32)` yields `"   f32"`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `pad` (unlike `write_str`) applies the formatter's width/alignment
        // flags, so dtypes line up in tabular output exactly like a `&str`.
        f.pad(self.name())
    }
}
55
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant, in declaration order.
    const ALL: [DType; 7] = [
        DType::F32,
        DType::F16,
        DType::BF16,
        DType::U8,
        DType::U16,
        DType::U32,
        DType::I32,
    ];

    #[test]
    fn test_dtype_size_of() {
        assert_eq!(DType::F32.size_of(), 4);
        assert_eq!(DType::F16.size_of(), 2);
        assert_eq!(DType::BF16.size_of(), 2);
        assert_eq!(DType::U8.size_of(), 1);
        assert_eq!(DType::U16.size_of(), 2);
        assert_eq!(DType::U32.size_of(), 4);
        assert_eq!(DType::I32.size_of(), 4);
    }

    #[test]
    fn test_dtype_size_of_matches_std_types() {
        // Cross-check against the Rust primitive each dtype corresponds to.
        // f16/bf16 have no stable std primitive, so they are only checked above.
        assert_eq!(DType::F32.size_of(), std::mem::size_of::<f32>());
        assert_eq!(DType::U8.size_of(), std::mem::size_of::<u8>());
        assert_eq!(DType::U16.size_of(), std::mem::size_of::<u16>());
        assert_eq!(DType::U32.size_of(), std::mem::size_of::<u32>());
        assert_eq!(DType::I32.size_of(), std::mem::size_of::<i32>());
    }

    #[test]
    fn test_dtype_display() {
        assert_eq!(format!("{}", DType::F32), "f32");
        assert_eq!(format!("{}", DType::BF16), "bf16");
        assert_eq!(format!("{}", DType::U8), "u8");
    }

    #[test]
    fn test_dtype_name_and_display_cover_every_variant() {
        for dtype in ALL.iter().copied() {
            // Exhaustive match: adding a variant without updating this table
            // is a compile error rather than a silently untested name.
            let expected = match dtype {
                DType::F32 => "f32",
                DType::F16 => "f16",
                DType::BF16 => "bf16",
                DType::U8 => "u8",
                DType::U16 => "u16",
                DType::U32 => "u32",
                DType::I32 => "i32",
            };
            assert_eq!(dtype.name(), expected);
            assert_eq!(dtype.to_string(), expected);
        }
    }
}