use std::fmt::Display;
#[cfg(any(feature = "cuda", feature = "metal"))]
use thiserror::Error;
/// Errors that can occur while creating an accelerator-backed [`Device`].
///
/// This type only exists when the `cuda` or `metal` feature is enabled, since
/// those are the only configurations with fallible device constructors.
#[cfg(any(feature = "cuda", feature = "metal"))]
#[derive(Error, Debug)]
pub enum DeviceError {
/// An error reported by `candle_core`. The `#[from]` attribute generates the
/// `From<candle_core::Error>` conversion, so it can be propagated with `?`.
#[error("Error from Candle: {0}")]
CandleError(#[from] candle_core::Error),
}
/// The kind of backend a [`Device`] refers to.
///
/// Accelerator variants only exist when the matching Cargo feature is enabled.
///
/// The enum carries no data, so it also derives `Copy` and `Eq`; values can be
/// passed around freely and compared without cloning.
// NOTE: the name is a misspelling of `DeviceType`. It is kept as-is because the
// rest of this file refers to the type by this name.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
enum DeviceTpye {
    /// The host CPU.
    Cpu,
    /// A CUDA device (requires the `cuda` feature).
    #[cfg(feature = "cuda")]
    Cuda,
    /// A Metal device (requires the `metal` feature).
    #[cfg(feature = "metal")]
    Metal,
}
/// A compute device handle that stores a `candle_core::Device` together with
/// the kind of device it is and the index it was created with.
#[derive(Debug, Clone)]
pub struct Device {
/// The wrapped candle device that this handle dereferences to.
inner: candle_core::Device,
/// Which backend `inner` belongs to (CPU, CUDA, or Metal).
device_type: DeviceTpye,
/// The index passed to the CUDA/Metal constructor; always `0` for the CPU.
index: usize,
}
impl Device {
    /// Creates a handle to the CPU device.
    ///
    /// The CPU device always uses index `0`.
    pub fn cpu() -> Self {
        let inner = candle_core::Device::Cpu;
        Self {
            inner,
            device_type: DeviceTpye::Cpu,
            index: 0,
        }
    }

    /// Creates a handle to the CUDA device with the given `index`.
    ///
    /// # Errors
    ///
    /// Returns [`DeviceError::CandleError`] if `candle_core::Device::new_cuda`
    /// fails for that index.
    #[cfg(feature = "cuda")]
    pub fn cuda(index: usize) -> Result<Self, DeviceError> {
        let inner = candle_core::Device::new_cuda(index)?;
        Ok(Self {
            inner,
            device_type: DeviceTpye::Cuda,
            index,
        })
    }

    /// Creates a handle to the CUDA device with the given `index`, falling back
    /// to the CPU device (index `0`) if the CUDA device cannot be created.
    ///
    /// The underlying error is discarded; use [`Device::cuda`] to observe it.
    #[cfg(feature = "cuda")]
    pub fn cuda_if_available(index: usize) -> Self {
        Self::cuda(index).unwrap_or_else(|_| Self::cpu())
    }

    /// Creates a handle to the Metal device with the given `index`.
    ///
    /// # Errors
    ///
    /// Returns [`DeviceError::CandleError`] if `candle_core::Device::new_metal`
    /// fails for that index.
    #[cfg(feature = "metal")]
    pub fn metal(index: usize) -> Result<Self, DeviceError> {
        let inner = candle_core::Device::new_metal(index)?;
        Ok(Self {
            inner,
            device_type: DeviceTpye::Metal,
            index,
        })
    }

    /// Creates a handle to the Metal device with the given `index`, falling back
    /// to the CPU device (index `0`) if the Metal device cannot be created.
    ///
    /// The underlying error is discarded; use [`Device::metal`] to observe it.
    #[cfg(feature = "metal")]
    pub fn metal_if_available(index: usize) -> Self {
        Self::metal(index).unwrap_or_else(|_| Self::cpu())
    }
}
/// Dereferences a [`Device`] to the `candle_core::Device` it wraps, so the
/// wrapper can be passed wherever a `&candle_core::Device` is expected.
impl std::ops::Deref for Device {
    type Target = candle_core::Device;

    fn deref(&self) -> &candle_core::Device {
        let Self { inner, .. } = self;
        inner
    }
}
/// Two devices are equal when they have the same kind and the same index.
///
/// The wrapped `candle_core::Device` itself is not compared.
impl PartialEq for Device {
    fn eq(&self, other: &Self) -> bool {
        let lhs = (&self.device_type, self.index);
        let rhs = (&other.device_type, other.index);
        lhs == rhs
    }
}

/// Equality on kind and index is a full equivalence relation.
impl Eq for Device {}
impl Display for DeviceTpye {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
DeviceTpye::Cpu => write!(f, "cpu"),
#[cfg(feature = "cuda")]
DeviceTpye::Cuda => write!(f, "cuda"),
#[cfg(feature = "metal")]
DeviceTpye::Metal => write!(f, "metal"),
}
}
}
/// Formats the device as its kind name, followed by `(index)` for CUDA and
/// Metal devices; the CPU device is printed as the bare kind name.
impl Display for Device {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let kind = &self.device_type;
        match kind {
            // The CPU has no meaningful index, so only the kind is shown.
            DeviceTpye::Cpu => write!(f, "{}", kind),
            #[cfg(feature = "cuda")]
            DeviceTpye::Cuda => write!(f, "{}({})", kind, self.index),
            #[cfg(feature = "metal")]
            DeviceTpye::Metal => write!(f, "{}({})", kind, self.index),
        }
    }
}