1use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
7pub enum Device {
8 CPU,
10 CUDA(usize),
12 ROCm(usize),
14 #[cfg(any(target_os = "macos", target_os = "ios"))]
16 Metal,
17}
18
19impl Device {
20 pub fn is_gpu(&self) -> bool {
22 matches!(self, Device::CUDA(_) | Device::ROCm(_)) || {
23 #[cfg(any(target_os = "macos", target_os = "ios"))]
24 {
25 matches!(self, Device::Metal)
26 }
27 #[cfg(not(any(target_os = "macos", target_os = "ios")))]
28 {
29 false
30 }
31 }
32 }
33
34 pub fn index(&self) -> Option<usize> {
36 match self {
37 Device::CUDA(idx) | Device::ROCm(idx) => Some(*idx),
38 _ => None,
39 }
40 }
41}
42
43impl std::fmt::Display for Device {
44 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45 match self {
46 Device::CPU => write!(f, "cpu"),
47 Device::CUDA(idx) => write!(f, "cuda:{}", idx),
48 Device::ROCm(idx) => write!(f, "rocm:{}", idx),
49 #[cfg(any(target_os = "macos", target_os = "ios"))]
50 Device::Metal => write!(f, "metal"),
51 }
52 }
53}
54
55#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
57pub enum DataType {
58 FP32,
60 FP16,
62 BF16,
64 FP8,
66 INT32,
68 INT16,
70 INT8,
72 INT4,
74 UINT32,
76 UINT16,
78 UINT8,
80 UINT4,
82 BOOL,
84}
85
86impl DataType {
87 pub fn size_bytes(&self) -> usize {
89 match self {
90 DataType::FP32 | DataType::INT32 | DataType::UINT32 => 4,
91 DataType::FP16 | DataType::BF16 | DataType::INT16 | DataType::UINT16 => 2,
92 DataType::FP8 | DataType::INT8 | DataType::UINT8 | DataType::BOOL => 1,
93 DataType::INT4 | DataType::UINT4 => 1, }
95 }
96
97 pub fn is_float(&self) -> bool {
99 matches!(
100 self,
101 DataType::FP32 | DataType::FP16 | DataType::BF16 | DataType::FP8
102 )
103 }
104
105 pub fn is_integer(&self) -> bool {
107 matches!(
108 self,
109 DataType::INT32
110 | DataType::INT16
111 | DataType::INT8
112 | DataType::INT4
113 | DataType::UINT32
114 | DataType::UINT16
115 | DataType::UINT8
116 | DataType::UINT4
117 )
118 }
119
120 pub fn is_quantized(&self) -> bool {
122 matches!(
123 self,
124 DataType::FP8 | DataType::INT8 | DataType::INT4 | DataType::UINT4
125 )
126 }
127}
128
129impl std::fmt::Display for DataType {
130 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
131 let name = match self {
132 DataType::FP32 => "fp32",
133 DataType::FP16 => "fp16",
134 DataType::BF16 => "bf16",
135 DataType::FP8 => "fp8",
136 DataType::INT32 => "int32",
137 DataType::INT16 => "int16",
138 DataType::INT8 => "int8",
139 DataType::INT4 => "int4",
140 DataType::UINT32 => "uint32",
141 DataType::UINT16 => "uint16",
142 DataType::UINT8 => "uint8",
143 DataType::UINT4 => "uint4",
144 DataType::BOOL => "bool",
145 };
146 write!(f, "{}", name)
147 }
148}