/// Element data type, identified by an explicit `i32` discriminant.
///
/// The discriminants below are the values produced by `i32::from(dtype)` and
/// accepted by `DType::try_from(i32)`; renumbering a variant changes that
/// integer mapping.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
#[repr(i32)]
pub enum DType {
    /// No data type chosen; this is the `Default` value.
    #[default]
    Unspecified = 0,
    /// Boolean.
    Bool = 1,
    /// Unsigned 8-bit integer.
    Uint8 = 2,
    /// Signed 32-bit integer.
    Int32 = 3,
    /// Signed 64-bit integer.
    Int64 = 4,
    /// 16-bit (half precision) float.
    Float16 = 5,
    /// 32-bit (single precision) float.
    Float32 = 6,
    /// 64-bit (double precision) float.
    Float64 = 7,
    /// Signed 8-bit integer.
    Int8 = 8,
    /// Signed 16-bit integer.
    Int16 = 9,
    /// Unsigned 16-bit integer.
    Uint16 = 10,
    /// Unsigned 32-bit integer.
    Uint32 = 11,
    /// Unsigned 64-bit integer.
    Uint64 = 12,
    /// 16-bit bfloat16 float.
    Bfloat16 = 13,
}
23
24impl TryFrom<i32> for DType {
25 type Error = &'static str;
26
27 fn try_from(value: i32) -> Result<Self, Self::Error> {
28 match value {
29 0 => Ok(Self::Unspecified),
30 1 => Ok(Self::Bool),
31 2 => Ok(Self::Uint8),
32 3 => Ok(Self::Int32),
33 4 => Ok(Self::Int64),
34 5 => Ok(Self::Float16),
35 6 => Ok(Self::Float32),
36 7 => Ok(Self::Float64),
37 8 => Ok(Self::Int8),
38 9 => Ok(Self::Int16),
39 10 => Ok(Self::Uint16),
40 11 => Ok(Self::Uint32),
41 12 => Ok(Self::Uint64),
42 13 => Ok(Self::Bfloat16),
43 _ => Err("invalid dtype"),
44 }
45 }
46}
47
48impl From<DType> for i32 {
49 fn from(value: DType) -> Self {
50 value as i32
51 }
52}
53
54impl DType {
55 pub const ALL: [DType; 14] = [
57 DType::Unspecified,
58 DType::Bool,
59 DType::Uint8,
60 DType::Int32,
61 DType::Int64,
62 DType::Float16,
63 DType::Float32,
64 DType::Float64,
65 DType::Int8,
66 DType::Int16,
67 DType::Uint16,
68 DType::Uint32,
69 DType::Uint64,
70 DType::Bfloat16,
71 ];
72
73 pub const fn name(self) -> &'static str {
75 match self {
76 DType::Unspecified => "unspecified",
77 DType::Bool => "bool",
78 DType::Uint8 => "uint8",
79 DType::Int32 => "int32",
80 DType::Int64 => "int64",
81 DType::Float16 => "float16",
82 DType::Float32 => "float32",
83 DType::Float64 => "float64",
84 DType::Int8 => "int8",
85 DType::Int16 => "int16",
86 DType::Uint16 => "uint16",
87 DType::Uint32 => "uint32",
88 DType::Uint64 => "uint64",
89 DType::Bfloat16 => "bfloat16",
90 }
91 }
92
93 pub fn from_name(name: &str) -> Option<Self> {
96 match name {
97 "bool" => Some(DType::Bool),
98 "uint8" => Some(DType::Uint8),
99 "int32" => Some(DType::Int32),
100 "int64" => Some(DType::Int64),
101 "float16" => Some(DType::Float16),
102 "float32" => Some(DType::Float32),
103 "float64" => Some(DType::Float64),
104 "int8" => Some(DType::Int8),
105 "int16" => Some(DType::Int16),
106 "uint16" => Some(DType::Uint16),
107 "uint32" => Some(DType::Uint32),
108 "uint64" => Some(DType::Uint64),
109 "bfloat16" => Some(DType::Bfloat16),
110 _ => None,
111 }
112 }
113
114 #[must_use]
117 pub const fn is_integer(self) -> bool {
118 matches!(
119 self,
120 DType::Uint8
121 | DType::Int8
122 | DType::Int16
123 | DType::Uint16
124 | DType::Int32
125 | DType::Uint32
126 | DType::Int64
127 | DType::Uint64
128 )
129 }
130
131 #[must_use]
133 pub const fn is_float(self) -> bool {
134 matches!(
135 self,
136 DType::Float16 | DType::Float32 | DType::Float64 | DType::Bfloat16
137 )
138 }
139}
140
141impl std::fmt::Display for DType {
142 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
143 f.write_str(self.name())
144 }
145}
146
147pub const fn dtype_size(dtype: DType) -> usize {
149 match dtype {
150 DType::Unspecified => 0,
151 DType::Bool | DType::Uint8 | DType::Int8 => 1,
152 DType::Float16 | DType::Bfloat16 | DType::Int16 | DType::Uint16 => 2,
153 DType::Int32 | DType::Uint32 | DType::Float32 => 4,
154 DType::Int64 | DType::Uint64 | DType::Float64 => 8,
155 }
156}
157
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dtype_i32_roundtrip() {
        for dtype in DType::ALL {
            let raw = i32::from(dtype);
            assert_eq!(DType::try_from(raw), Ok(dtype));
        }
    }

    #[test]
    fn test_dtype_rejects_unknown_values() {
        assert!(DType::try_from(-1).is_err());
        assert!(DType::try_from(14).is_err());
    }

    #[test]
    fn test_dtype_all_matches_discriminants() {
        // ALL is indexed by discriminant: ALL[i] is the variant whose code is i.
        for (index, dtype) in DType::ALL.iter().copied().enumerate() {
            let raw = i32::try_from(index).expect("ALL has only a handful of entries");
            assert_eq!(DType::try_from(raw), Ok(dtype), "ALL[{index}] is out of order");
        }
        // The first code past the end of ALL must be rejected, so the contiguous
        // range of accepted codes ends exactly at ALL.len().
        let past_end = i32::try_from(DType::ALL.len()).expect("ALL has only a handful of entries");
        assert!(DType::try_from(past_end).is_err());
    }

    #[test]
    fn test_dtype_name_roundtrip() {
        for dtype in DType::ALL {
            if dtype == DType::Unspecified {
                continue;
            }
            assert_eq!(DType::from_name(dtype.name()), Some(dtype));
            assert_eq!(dtype.to_string(), dtype.name());
        }
        assert_eq!(DType::Unspecified.name(), "unspecified");
        assert_eq!(DType::from_name("unspecified"), None);
        assert_eq!(DType::from_name("complex64"), None);
        assert_eq!(DType::from_name("Float32"), None);
    }

    #[test]
    fn test_dtype_integer_float_classification() {
        for dtype in DType::ALL {
            assert!(
                !(dtype.is_integer() && dtype.is_float()),
                "{dtype:?} is classified as both integer and float"
            );
        }
        // Bool and Unspecified belong to neither numeric class.
        assert!(!DType::Bool.is_integer());
        assert!(!DType::Bool.is_float());
        assert!(!DType::Unspecified.is_integer());
        assert!(!DType::Unspecified.is_float());
        assert_eq!(DType::ALL.iter().filter(|d| d.is_integer()).count(), 8);
        assert_eq!(DType::ALL.iter().filter(|d| d.is_float()).count(), 4);
    }

    #[test]
    fn test_dtype_size_table() {
        let expected = [
            (DType::Unspecified, 0),
            (DType::Bool, 1),
            (DType::Uint8, 1),
            (DType::Int8, 1),
            (DType::Float16, 2),
            (DType::Bfloat16, 2),
            (DType::Int16, 2),
            (DType::Uint16, 2),
            (DType::Int32, 4),
            (DType::Uint32, 4),
            (DType::Float32, 4),
            (DType::Int64, 8),
            (DType::Uint64, 8),
            (DType::Float64, 8),
        ];
        for (dtype, size) in expected {
            assert_eq!(dtype_size(dtype), size, "size mismatch for {dtype:?}");
        }
    }
}