use burn_tensor::backend::DeviceOps;
use std::fmt;
/// Compute device targeted by the MLX backend.
///
/// [`MlxDevice::Gpu`] is the variant returned by [`Default::default`].
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum MlxDevice {
    /// MLX's CPU device.
    Cpu,
    /// MLX's GPU device; this is the default.
    #[default]
    Gpu,
}
impl MlxDevice {
    /// Builds the `mlx_rs::Device` handle that corresponds to this variant.
    ///
    /// `Cpu` maps to `mlx_rs::Device::cpu()` and `Gpu` maps to `mlx_rs::Device::gpu()`.
    pub fn to_mlx_device(&self) -> mlx_rs::Device {
        match *self {
            Self::Cpu => mlx_rs::Device::cpu(),
            Self::Gpu => mlx_rs::Device::gpu(),
        }
    }
}
impl fmt::Display for MlxDevice {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
MlxDevice::Cpu => write!(f, "MLX CPU"),
MlxDevice::Gpu => write!(f, "MLX GPU"),
}
}
}
impl DeviceOps for MlxDevice {
    /// Returns the device id: type id `0` for CPU and `1` for GPU, with index `0` in both cases.
    fn id(&self) -> burn_tensor::backend::DeviceId {
        // The enum carries no per-device index, so the index component is always 0.
        let type_id = match self {
            Self::Cpu => 0,
            Self::Gpu => 1,
        };
        burn_tensor::backend::DeviceId::new(type_id, 0)
    }
}