1use std::fmt;
4
/// Element data type, identified by a short lowercase name (e.g. `"f32"`)
/// and a fixed size in bytes per value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DType {
    /// 32-bit float (`"f32"`), 4 bytes.
    F32,
    /// 16-bit float (`"f16"`), 2 bytes.
    F16,
    /// 16-bit bfloat16 (`"bf16"`), 2 bytes.
    BF16,
    /// Unsigned 8-bit integer (`"u8"`), 1 byte.
    U8,
    /// Unsigned 16-bit integer (`"u16"`), 2 bytes.
    U16,
    /// Unsigned 32-bit integer (`"u32"`), 4 bytes.
    U32,
    /// Signed 32-bit integer (`"i32"`), 4 bytes.
    I32,
}

impl DType {
    /// Returns the size in bytes of a single value of this type.
    ///
    /// This is a `const fn`, so it can be used in constant expressions,
    /// e.g. `const W: usize = DType::F32.size_of();`.
    #[inline]
    #[must_use]
    pub const fn size_of(self) -> usize {
        match self {
            DType::F32 | DType::U32 | DType::I32 => 4,
            DType::F16 | DType::BF16 | DType::U16 => 2,
            DType::U8 => 1,
        }
    }

    /// Returns the canonical lowercase name of this type (e.g. `"bf16"`).
    ///
    /// This is exactly the text produced by the [`Display`](fmt::Display) impl.
    #[inline]
    #[must_use]
    pub const fn name(self) -> &'static str {
        match self {
            DType::F32 => "f32",
            DType::F16 => "f16",
            DType::BF16 => "bf16",
            DType::U8 => "u8",
            DType::U16 => "u16",
            DType::U32 => "u32",
            DType::I32 => "i32",
        }
    }
}

impl fmt::Display for DType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `pad` (unlike `write_str`) honours width/fill/alignment/precision
        // flags, so `format!("{:>5}", DType::F32)` yields "  f32" just as a
        // plain `&str` would; without flags the output is unchanged.
        f.pad(self.name())
    }
}
55
#[cfg(test)]
mod tests {
    use super::*;

    /// Expected byte size and name for every `DType` variant, so each test
    /// below checks all variants rather than a hand-picked subset.
    const CASES: [(DType, usize, &str); 7] = [
        (DType::F32, 4, "f32"),
        (DType::F16, 2, "f16"),
        (DType::BF16, 2, "bf16"),
        (DType::U8, 1, "u8"),
        (DType::U16, 2, "u16"),
        (DType::U32, 4, "u32"),
        (DType::I32, 4, "i32"),
    ];

    #[test]
    fn test_dtype_size_of() {
        for &(dtype, size, _) in CASES.iter() {
            assert_eq!(dtype.size_of(), size, "size_of({:?})", dtype);
        }
    }

    #[test]
    fn test_dtype_display() {
        for &(dtype, _, name) in CASES.iter() {
            assert_eq!(format!("{}", dtype), name, "Display for {:?}", dtype);
        }
    }

    #[test]
    fn test_dtype_name_matches_display() {
        for &(dtype, _, name) in CASES.iter() {
            assert_eq!(dtype.name(), name, "name() for {:?}", dtype);
            // Display must emit exactly the canonical name.
            assert_eq!(dtype.to_string(), dtype.name());
        }
    }
}