use serde::{Deserialize, Serialize};
/// Compute backend selector.
///
/// Indexed backends (`CUDA`, `ROCm`) carry a device index, which
/// `Device::index` returns and the `Display` impl renders as `cuda:<n>` /
/// `rocm:<n>`.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub enum Device {
    /// Host CPU (no device index).
    CPU,
    /// CUDA device with the given index.
    CUDA(usize),
    /// ROCm device with the given index.
    ROCm(usize),
    /// Apple Metal device (no device index).
    ///
    /// This variant only exists when compiling for macOS or iOS. A value
    /// serialized with it cannot be deserialized by a build for any other
    /// target, because the variant is absent there.
    #[cfg(any(target_os = "macos", target_os = "ios"))]
    Metal,
}
impl Device {
    /// Returns `true` for accelerator backends and `false` for the CPU.
    ///
    /// `CUDA` and `ROCm` are always accelerators; `Metal` (present only on
    /// macOS/iOS builds) is too.
    #[must_use]
    pub fn is_gpu(&self) -> bool {
        // Exhaustive on purpose: a newly added variant must be classified
        // here explicitly instead of falling into a default.
        match self {
            Device::CPU => false,
            Device::CUDA(_) | Device::ROCm(_) => true,
            #[cfg(any(target_os = "macos", target_os = "ios"))]
            Device::Metal => true,
        }
    }

    /// Returns the device index carried by indexed backends.
    ///
    /// # Returns
    /// `Some(idx)` for `CUDA(idx)` / `ROCm(idx)`. `None` for the CPU and, on
    /// Apple targets, `Metal`, which carry no index.
    #[must_use]
    pub fn index(&self) -> Option<usize> {
        // No catch-all arm: new variants must state whether they have an index.
        match self {
            Device::CUDA(idx) | Device::ROCm(idx) => Some(*idx),
            Device::CPU => None,
            #[cfg(any(target_os = "macos", target_os = "ios"))]
            Device::Metal => None,
        }
    }
}
impl std::fmt::Display for Device {
    /// Writes the lowercase backend name, followed by `:<index>` for indexed
    /// backends: `cpu`, `cuda:0`, `rocm:1`, `metal`.
    ///
    /// Width/alignment flags on the outer formatter are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Split each variant into its textual backend and optional index.
        let (backend, ordinal): (&str, Option<usize>) = match self {
            Device::CPU => ("cpu", None),
            Device::CUDA(n) => ("cuda", Some(*n)),
            Device::ROCm(n) => ("rocm", Some(*n)),
            #[cfg(any(target_os = "macos", target_os = "ios"))]
            Device::Metal => ("metal", None),
        };
        f.write_str(backend)?;
        if let Some(n) = ordinal {
            write!(f, ":{}", n)?;
        }
        Ok(())
    }
}
/// Scalar element data type.
///
/// Variant order is significant: the derived `Hash` and any serde format
/// that encodes enum variants by index depend on it, so append new variants
/// rather than inserting them.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub enum DataType {
    /// 32-bit floating point.
    FP32,
    /// 16-bit floating point.
    FP16,
    /// bfloat16 (16-bit floating point).
    BF16,
    /// 8-bit floating point.
    FP8,
    /// 32-bit signed integer.
    INT32,
    /// 16-bit signed integer.
    INT16,
    /// 8-bit signed integer.
    INT8,
    /// 4-bit signed integer.
    INT4,
    /// 32-bit unsigned integer.
    UINT32,
    /// 16-bit unsigned integer.
    UINT16,
    /// 8-bit unsigned integer.
    UINT8,
    /// 4-bit unsigned integer.
    UINT4,
    /// Boolean.
    BOOL,
}
impl DataType {
    /// Number of whole bytes occupied by one element of this type.
    ///
    /// Sub-byte types (`INT4`, `UINT4`) report `1`: their 4-bit width is
    /// rounded up to a whole byte. If such elements are stored packed
    /// (several per byte), this overstates the storage; use
    /// [`DataType::size_bits`] when the exact width is needed.
    #[must_use]
    pub fn size_bytes(&self) -> usize {
        match self {
            DataType::FP32 | DataType::INT32 | DataType::UINT32 => 4,
            DataType::FP16 | DataType::BF16 | DataType::INT16 | DataType::UINT16 => 2,
            DataType::FP8 | DataType::INT8 | DataType::UINT8 | DataType::BOOL => 1,
            // Sub-byte widths are rounded up to one byte.
            DataType::INT4 | DataType::UINT4 => 1,
        }
    }

    /// Width of one element in bits.
    ///
    /// Unlike [`DataType::size_bytes`], this is exact for sub-byte types
    /// (`INT4`/`UINT4` report `4`). `BOOL` reports `8`, matching its one-byte
    /// `size_bytes`. For every variant,
    /// `size_bytes() == (size_bits() + 7) / 8`.
    #[must_use]
    pub fn size_bits(&self) -> usize {
        match self {
            DataType::FP32 | DataType::INT32 | DataType::UINT32 => 32,
            DataType::FP16 | DataType::BF16 | DataType::INT16 | DataType::UINT16 => 16,
            DataType::FP8 | DataType::INT8 | DataType::UINT8 | DataType::BOOL => 8,
            DataType::INT4 | DataType::UINT4 => 4,
        }
    }

    /// Returns `true` for the floating-point types `FP32`, `FP16`, `BF16`
    /// and `FP8`.
    #[must_use]
    pub fn is_float(&self) -> bool {
        matches!(
            self,
            DataType::FP32 | DataType::FP16 | DataType::BF16 | DataType::FP8
        )
    }

    /// Returns `true` for the signed (`INT*`) and unsigned (`UINT*`) integer
    /// types.
    ///
    /// `BOOL` is neither an integer nor a float.
    #[must_use]
    pub fn is_integer(&self) -> bool {
        matches!(
            self,
            DataType::INT32
                | DataType::INT16
                | DataType::INT8
                | DataType::INT4
                | DataType::UINT32
                | DataType::UINT16
                | DataType::UINT8
                | DataType::UINT4
        )
    }

    /// Returns `true` for the types this crate treats as quantized: `FP8`,
    /// `INT8`, `INT4` and `UINT4`.
    #[must_use]
    pub fn is_quantized(&self) -> bool {
        // NOTE(review): UINT8 is not listed although UINT4 is — confirm this
        // is intentional before relying on it.
        matches!(
            self,
            DataType::FP8 | DataType::INT8 | DataType::INT4 | DataType::UINT4
        )
    }
}
impl std::fmt::Display for DataType {
    /// Writes the lowercase short name of the type (`fp32`, `bf16`, `uint4`,
    /// `bool`, ...).
    ///
    /// Width/alignment flags on the outer formatter are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match *self {
            DataType::FP32 => "fp32",
            DataType::FP16 => "fp16",
            DataType::BF16 => "bf16",
            DataType::FP8 => "fp8",
            DataType::INT32 => "int32",
            DataType::INT16 => "int16",
            DataType::INT8 => "int8",
            DataType::INT4 => "int4",
            DataType::UINT32 => "uint32",
            DataType::UINT16 => "uint16",
            DataType::UINT8 => "uint8",
            DataType::UINT4 => "uint4",
            DataType::BOOL => "bool",
        })
    }
}