Skip to main content

rlmesh_spaces/
dtype.rs

/// Element data type of the tensor values that travel over the wire.
///
/// Each discriminant is pinned to the corresponding value of the
/// `rlmesh.spaces.v1.DType` proto enum; keep the two in sync when adding
/// variants.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
#[repr(i32)]
pub enum DType {
    /// No dtype was set; this is also the [`Default`] value.
    #[default]
    Unspecified = 0,
    Bool = 1,
    Uint8 = 2,
    Int32 = 3,
    Int64 = 4,
    Float16 = 5,
    Float32 = 6,
    Float64 = 7,
    Int8 = 8,
    Int16 = 9,
    Uint16 = 10,
    Uint32 = 11,
    Uint64 = 12,
    Bfloat16 = 13,
}
23
24impl TryFrom<i32> for DType {
25    type Error = &'static str;
26
27    fn try_from(value: i32) -> Result<Self, Self::Error> {
28        match value {
29            0 => Ok(Self::Unspecified),
30            1 => Ok(Self::Bool),
31            2 => Ok(Self::Uint8),
32            3 => Ok(Self::Int32),
33            4 => Ok(Self::Int64),
34            5 => Ok(Self::Float16),
35            6 => Ok(Self::Float32),
36            7 => Ok(Self::Float64),
37            8 => Ok(Self::Int8),
38            9 => Ok(Self::Int16),
39            10 => Ok(Self::Uint16),
40            11 => Ok(Self::Uint32),
41            12 => Ok(Self::Uint64),
42            13 => Ok(Self::Bfloat16),
43            _ => Err("invalid dtype"),
44        }
45    }
46}
47
48impl From<DType> for i32 {
49    fn from(value: DType) -> Self {
50        value as i32
51    }
52}
53
54impl DType {
55    /// Every dtype, including `Unspecified`, in discriminant order.
56    pub const ALL: [DType; 14] = [
57        DType::Unspecified,
58        DType::Bool,
59        DType::Uint8,
60        DType::Int32,
61        DType::Int64,
62        DType::Float16,
63        DType::Float32,
64        DType::Float64,
65        DType::Int8,
66        DType::Int16,
67        DType::Uint16,
68        DType::Uint32,
69        DType::Uint64,
70        DType::Bfloat16,
71    ];
72
73    /// The canonical lowercase dtype name (for example `"float32"`).
74    pub const fn name(self) -> &'static str {
75        match self {
76            DType::Unspecified => "unspecified",
77            DType::Bool => "bool",
78            DType::Uint8 => "uint8",
79            DType::Int32 => "int32",
80            DType::Int64 => "int64",
81            DType::Float16 => "float16",
82            DType::Float32 => "float32",
83            DType::Float64 => "float64",
84            DType::Int8 => "int8",
85            DType::Int16 => "int16",
86            DType::Uint16 => "uint16",
87            DType::Uint32 => "uint32",
88            DType::Uint64 => "uint64",
89            DType::Bfloat16 => "bfloat16",
90        }
91    }
92
93    /// Parse a canonical dtype name. Only the 13 concrete dtypes are
94    /// recognized; `"unspecified"` and unknown names return `None`.
95    pub fn from_name(name: &str) -> Option<Self> {
96        match name {
97            "bool" => Some(DType::Bool),
98            "uint8" => Some(DType::Uint8),
99            "int32" => Some(DType::Int32),
100            "int64" => Some(DType::Int64),
101            "float16" => Some(DType::Float16),
102            "float32" => Some(DType::Float32),
103            "float64" => Some(DType::Float64),
104            "int8" => Some(DType::Int8),
105            "int16" => Some(DType::Int16),
106            "uint16" => Some(DType::Uint16),
107            "uint32" => Some(DType::Uint32),
108            "uint64" => Some(DType::Uint64),
109            "bfloat16" => Some(DType::Bfloat16),
110            _ => None,
111        }
112    }
113
114    /// Whether this is an integer dtype (signed or unsigned). `Bool` and the
115    /// float dtypes are not integers.
116    #[must_use]
117    pub const fn is_integer(self) -> bool {
118        matches!(
119            self,
120            DType::Uint8
121                | DType::Int8
122                | DType::Int16
123                | DType::Uint16
124                | DType::Int32
125                | DType::Uint32
126                | DType::Int64
127                | DType::Uint64
128        )
129    }
130
131    /// Whether this is a floating-point dtype.
132    #[must_use]
133    pub const fn is_float(self) -> bool {
134        matches!(
135            self,
136            DType::Float16 | DType::Float32 | DType::Float64 | DType::Bfloat16
137        )
138    }
139}
140
141impl std::fmt::Display for DType {
142    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
143        f.write_str(self.name())
144    }
145}
146
147/// Get the byte size of a dtype. `Unspecified` has no size and returns 0.
148pub const fn dtype_size(dtype: DType) -> usize {
149    match dtype {
150        DType::Unspecified => 0,
151        DType::Bool | DType::Uint8 | DType::Int8 => 1,
152        DType::Float16 | DType::Bfloat16 | DType::Int16 | DType::Uint16 => 2,
153        DType::Int32 | DType::Uint32 | DType::Float32 => 4,
154        DType::Int64 | DType::Uint64 | DType::Float64 => 8,
155    }
156}
157
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dtype_i32_roundtrip() {
        for dtype in DType::ALL {
            let raw = i32::from(dtype);
            assert_eq!(DType::try_from(raw), Ok(dtype));
        }
    }

    #[test]
    fn test_dtype_all_is_in_discriminant_order() {
        // `ALL` is documented as discriminant-ordered and `from_name` searches
        // it; pin that index `i` holds discriminant `i` so a reordered or
        // incomplete table fails here.
        for (index, dtype) in DType::ALL.iter().enumerate() {
            assert_eq!(
                usize::try_from(i32::from(*dtype)).ok(),
                Some(index),
                "{dtype:?} is out of discriminant order in DType::ALL"
            );
        }
    }

    #[test]
    fn test_dtype_rejects_unknown_values() {
        for raw in [-1, 14, i32::MIN, i32::MAX] {
            assert!(DType::try_from(raw).is_err(), "{raw} should be rejected");
        }
    }

    #[test]
    fn test_dtype_name_roundtrip() {
        for dtype in DType::ALL {
            if dtype == DType::Unspecified {
                continue;
            }
            assert_eq!(DType::from_name(dtype.name()), Some(dtype));
            assert_eq!(dtype.to_string(), dtype.name());
        }
        assert_eq!(DType::Unspecified.name(), "unspecified");
        assert_eq!(DType::from_name("unspecified"), None);
        assert_eq!(DType::from_name("complex64"), None);
        assert_eq!(DType::from_name("Float32"), None);
        assert_eq!(DType::from_name(""), None);
    }

    #[test]
    fn test_dtype_category_predicates() {
        for dtype in DType::ALL {
            assert!(
                !(dtype.is_integer() && dtype.is_float()),
                "{dtype:?} is classified as both integer and float"
            );
            // Only `Unspecified` and `Bool` fall outside both categories.
            let categorized = dtype.is_integer() || dtype.is_float();
            let uncategorized = matches!(dtype, DType::Unspecified | DType::Bool);
            assert_eq!(categorized, !uncategorized, "unexpected category for {dtype:?}");
        }
        assert!(DType::Bfloat16.is_float());
        assert!(DType::Uint64.is_integer());
    }

    #[test]
    fn test_dtype_size_table() {
        let expected = [
            (DType::Unspecified, 0),
            (DType::Bool, 1),
            (DType::Uint8, 1),
            (DType::Int8, 1),
            (DType::Float16, 2),
            (DType::Bfloat16, 2),
            (DType::Int16, 2),
            (DType::Uint16, 2),
            (DType::Int32, 4),
            (DType::Uint32, 4),
            (DType::Float32, 4),
            (DType::Int64, 8),
            (DType::Uint64, 8),
            (DType::Float64, 8),
        ];
        // Guard against a new variant being added without a size entry here.
        assert_eq!(expected.len(), DType::ALL.len());
        for (dtype, size) in expected {
            assert_eq!(dtype_size(dtype), size, "size mismatch for {dtype:?}");
        }
    }
}
209        for (dtype, size) in expected {
210            assert_eq!(dtype_size(dtype), size, "size mismatch for {dtype:?}");
211        }
212    }
213}