1use std::fmt;
4
/// Element data type: floating point, signed integer, unsigned integer, or boolean.
///
/// Variants use implicit discriminants in declaration order (`Float32` is 0, `Bool` is 11),
/// so reordering them changes the value of `dtype as usize`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum DType {
    /// 32-bit floating point.
    Float32,
    /// 16-bit floating point.
    Float16,
    /// 64-bit floating point.
    Float64,
    /// 8-bit signed integer.
    Int8,
    /// 16-bit signed integer.
    Int16,
    /// 32-bit signed integer.
    Int32,
    /// 64-bit signed integer.
    Int64,
    /// 8-bit unsigned integer.
    Uint8,
    /// 16-bit unsigned integer.
    Uint16,
    /// 32-bit unsigned integer.
    Uint32,
    /// 64-bit unsigned integer.
    Uint64,
    /// Boolean value, one byte wide.
    Bool,
}
36
37impl DType {
38 #[inline]
40 pub const fn byte_width(self) -> usize {
41 match self {
42 DType::Bool | DType::Int8 | DType::Uint8 => 1,
43 DType::Float16 | DType::Int16 | DType::Uint16 => 2,
44 DType::Float32 | DType::Int32 | DType::Uint32 => 4,
45 DType::Float64 | DType::Int64 | DType::Uint64 => 8,
46 }
47 }
48
49 #[inline]
51 pub const fn is_float(self) -> bool {
52 matches!(self, DType::Float32 | DType::Float16 | DType::Float64)
53 }
54
55 #[inline]
57 pub const fn is_int(self) -> bool {
58 matches!(
59 self,
60 DType::Int8
61 | DType::Int16
62 | DType::Int32
63 | DType::Int64
64 | DType::Uint8
65 | DType::Uint16
66 | DType::Uint32
67 | DType::Uint64
68 )
69 }
70
71 #[inline]
73 pub const fn is_signed(self) -> bool {
74 matches!(
75 self,
76 DType::Int8 | DType::Int16 | DType::Int32 | DType::Int64
77 )
78 }
79
80 #[inline]
82 pub const fn is_unsigned(self) -> bool {
83 matches!(
84 self,
85 DType::Uint8 | DType::Uint16 | DType::Uint32 | DType::Uint64
86 )
87 }
88
89 pub fn promote(dtype1: DType, dtype2: DType) -> DType {
103 if dtype1 == dtype2 {
104 return dtype1;
105 }
106
107 let rank = |d: DType| match d {
109 DType::Bool => 0,
110 DType::Uint8 => 1,
111 DType::Uint16 => 2,
112 DType::Uint32 => 3,
113 DType::Uint64 => 4,
114 DType::Int8 => 5,
115 DType::Int16 => 6,
116 DType::Int32 => 7,
117 DType::Int64 => 8,
118 DType::Float16 => 9,
119 DType::Float32 => 10,
120 DType::Float64 => 11,
121 };
122
123 if rank(dtype1) > rank(dtype2) {
124 dtype1
125 } else {
126 dtype2
127 }
128 }
129
130 #[inline]
132 pub fn cast_from_f32(self, value: f32) -> f32 {
133 match self {
134 DType::Float32 | DType::Float64 | DType::Float16 => value,
135 DType::Int8 => (value as i8) as f32,
136 DType::Int16 => (value as i16) as f32,
137 DType::Int32 | DType::Int64 => (value as i32) as f32,
138 DType::Uint8 => (value as u8) as f32,
139 DType::Uint16 => (value as u16) as f32,
140 DType::Uint32 | DType::Uint64 => (value as u32) as f32,
141 DType::Bool => if value != 0.0 { 1.0 } else { 0.0 },
142 }
143 }
144}
145
146impl fmt::Display for DType {
147 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
148 match self {
149 DType::Float32 => write!(f, "float32"),
150 DType::Float16 => write!(f, "float16"),
151 DType::Float64 => write!(f, "float64"),
152 DType::Int8 => write!(f, "int8"),
153 DType::Int16 => write!(f, "int16"),
154 DType::Int32 => write!(f, "int32"),
155 DType::Int64 => write!(f, "int64"),
156 DType::Uint8 => write!(f, "uint8"),
157 DType::Uint16 => write!(f, "uint16"),
158 DType::Uint32 => write!(f, "uint32"),
159 DType::Uint64 => write!(f, "uint64"),
160 DType::Bool => write!(f, "bool"),
161 }
162 }
163}
164
165impl DType {
166 pub fn from_str(s: &str) -> Option<DType> {
168 match s.to_lowercase().as_str() {
169 "float32" | "f32" => Some(DType::Float32),
170 "float16" | "f16" => Some(DType::Float16),
171 "float64" | "f64" => Some(DType::Float64),
172 "int8" | "i8" => Some(DType::Int8),
173 "int16" | "i16" => Some(DType::Int16),
174 "int32" | "i32" => Some(DType::Int32),
175 "int64" | "i64" => Some(DType::Int64),
176 "uint8" | "u8" => Some(DType::Uint8),
177 "uint16" | "u16" => Some(DType::Uint16),
178 "uint32" | "u32" => Some(DType::Uint32),
179 "uint64" | "u64" => Some(DType::Uint64),
180 "bool" => Some(DType::Bool),
181 _ => None,
182 }
183 }
184}
185
// Unit tests for `DType`, written as tables of (input, expected) pairs checked in a loop.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_byte_width() {
        let cases = [
            (DType::Float32, 4),
            (DType::Float16, 2),
            (DType::Float64, 8),
            (DType::Int8, 1),
            (DType::Int16, 2),
            (DType::Int32, 4),
            (DType::Int64, 8),
            (DType::Uint8, 1),
            (DType::Uint16, 2),
            (DType::Uint32, 4),
            (DType::Uint64, 8),
            (DType::Bool, 1),
        ];
        for &(dtype, width) in cases.iter() {
            assert_eq!(dtype.byte_width(), width, "byte_width of {:?}", dtype);
        }
    }

    #[test]
    fn test_is_float() {
        let cases = [
            (DType::Float32, true),
            (DType::Float16, true),
            (DType::Float64, true),
            (DType::Int32, false),
            (DType::Uint32, false),
            (DType::Bool, false),
        ];
        for &(dtype, expected) in cases.iter() {
            assert_eq!(dtype.is_float(), expected, "is_float of {:?}", dtype);
        }
    }

    #[test]
    fn test_is_int() {
        let cases = [
            (DType::Int8, true),
            (DType::Int16, true),
            (DType::Int32, true),
            (DType::Int64, true),
            (DType::Uint8, true),
            (DType::Uint16, true),
            (DType::Uint32, true),
            (DType::Uint64, true),
            (DType::Float32, false),
            (DType::Bool, false),
        ];
        for &(dtype, expected) in cases.iter() {
            assert_eq!(dtype.is_int(), expected, "is_int of {:?}", dtype);
        }
    }

    #[test]
    fn test_type_promotion() {
        let cases = [
            (DType::Bool, DType::Int32, DType::Int32),
            (DType::Uint32, DType::Int32, DType::Int32),
            (DType::Int32, DType::Float16, DType::Float16),
            (DType::Float16, DType::Float32, DType::Float32),
            (DType::Uint32, DType::Float32, DType::Float32),
            (DType::Float32, DType::Float32, DType::Float32),
            (DType::Uint8, DType::Uint16, DType::Uint16),
            (DType::Int8, DType::Int16, DType::Int16),
        ];
        for &(lhs, rhs, expected) in cases.iter() {
            assert_eq!(
                DType::promote(lhs, rhs),
                expected,
                "promote({:?}, {:?})",
                lhs,
                rhs
            );
        }
    }

    #[test]
    fn test_display() {
        let cases = [
            (DType::Float32, "float32"),
            (DType::Float16, "float16"),
            (DType::Float64, "float64"),
            (DType::Int8, "int8"),
            (DType::Int16, "int16"),
            (DType::Int32, "int32"),
            (DType::Int64, "int64"),
            (DType::Uint8, "uint8"),
            (DType::Uint16, "uint16"),
            (DType::Uint32, "uint32"),
            (DType::Uint64, "uint64"),
            (DType::Bool, "bool"),
        ];
        for &(dtype, name) in cases.iter() {
            assert_eq!(dtype.to_string(), name);
        }
    }

    #[test]
    fn test_from_str() {
        let cases = [
            ("float32", Some(DType::Float32)),
            ("f32", Some(DType::Float32)),
            ("int8", Some(DType::Int8)),
            ("i8", Some(DType::Int8)),
            ("uint16", Some(DType::Uint16)),
            ("bool", Some(DType::Bool)),
            ("unknown", None),
        ];
        for &(text, expected) in cases.iter() {
            assert_eq!(DType::from_str(text), expected, "from_str({:?})", text);
        }
    }

    #[test]
    fn test_cast_from_f32() {
        let cases = [
            (DType::Int8, 127.5_f32, 127.0_f32),
            (DType::Int8, -128.0, -128.0),
            (DType::Uint8, 255.5, 255.0),
            (DType::Bool, 0.0, 0.0),
            (DType::Bool, 42.0, 1.0),
        ];
        for &(dtype, input, expected) in cases.iter() {
            assert_eq!(
                dtype.cast_from_f32(input),
                expected,
                "{:?}.cast_from_f32({})",
                dtype,
                input
            );
        }
    }
}