use std::fmt::Display;
use thiserror::Error;
#[derive(Error, Debug)]
pub enum DeviceError {
#[error("Error from Candle: {0}")]
CandleError(#[from] candle_core::Error),
#[error("Other error: {0}")]
OtherError(String),
}
/// Compute backend a [`Device`] refers to.
///
/// Non-CPU variants exist only when the matching cargo feature is enabled.
/// `Eq` and `Hash` are derived so the kind can serve as a map/set key. They are
/// consistent with the derived `PartialEq` and with `Device` implementing `Eq`.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
pub enum DeviceKind {
    /// Host CPU.
    Cpu,
    /// CUDA GPU backend (requires the `cuda` feature).
    #[cfg(feature = "cuda")]
    Cuda,
    /// Metal GPU backend (requires the `metal` feature).
    #[cfg(feature = "metal")]
    Metal,
}
/// Handle to a compute device. It wraps a `candle_core::Device` together with
/// its backend kind and ordinal.
///
/// Build one with [`Device::cpu`] or one of the feature-gated constructors.
#[derive(Clone, Debug)]
pub struct Device {
    /// The wrapped candle device; exposed read-only through `Deref`.
    inner: candle_core::Device,
    /// Backend this handle refers to.
    kind: DeviceKind,
    /// Device ordinal; the visible constructors set it to `0` for the CPU.
    index: usize,
}
impl Device {
pub fn cpu() -> Self {
Self {
inner: candle_core::Device::Cpu,
kind: DeviceKind::Cpu,
index: 0,
}
}
#[cfg(feature = "cuda")]
pub fn cuda(index: usize) -> Result<Self, DeviceError> {
Ok(Self {
inner: candle_core::Device::new_cuda(index)?,
kind: DeviceKind::Cuda,
index,
})
}
#[cfg(feature = "cuda")]
pub fn cuda_if_available(index: usize) -> Self {
let (device, kind, index) = match candle_core::Device::new_cuda(index) {
Ok(device) => (device, DeviceKind::Cuda, index),
Err(_) => (candle_core::Device::Cpu, DeviceKind::Cpu, 0),
};
Self {
inner: device,
kind,
index,
}
}
#[cfg(feature = "metal")]
pub fn metal(index: usize) -> Result<Self, DeviceError> {
Ok(Self {
inner: candle_core::Device::new_metal(index)?,
kind: DeviceKind::Metal,
index,
})
}
#[cfg(feature = "metal")]
pub fn metal_if_available(index: usize) -> Self {
let (device, kind, index) = match candle_core::Device::new_metal(index) {
Ok(device) => (device, DeviceKind::Metal, index),
Err(_) => (candle_core::Device::Cpu, DeviceKind::Cpu, 0),
};
Self {
inner: device,
kind,
index,
}
}
pub fn kind(&self) -> DeviceKind {
self.kind
}
pub fn index(&self) -> usize {
self.index
}
pub fn is_cpu(&self) -> bool {
self.kind == DeviceKind::Cpu
}
#[cfg(feature = "cuda")]
pub fn is_cuda(&self) -> bool {
self.kind == DeviceKind::Cuda
}
#[cfg(feature = "metal")]
pub fn is_metal(&self) -> bool {
self.kind == DeviceKind::Metal
}
pub fn synchronize(&self) -> Result<(), DeviceError> {
self.inner.synchronize()?;
Ok(())
}
}
/// Exposes the wrapped `candle_core::Device`, so its methods can be called
/// directly on a [`Device`] (read-only access).
impl std::ops::Deref for Device {
    type Target = candle_core::Device;

    fn deref(&self) -> &candle_core::Device {
        let Self { inner, .. } = self;
        inner
    }
}
/// Two handles are equal when they share the same [`DeviceKind`] and index.
///
/// NOTE(review): the wrapped candle device is not compared. If candle treats
/// separately created handles with the same ordinal as distinct devices, this
/// equality is coarser than candle's own notion — confirm that is intended.
impl PartialEq for Device {
    fn eq(&self, rhs: &Self) -> bool {
        (self.kind, self.index) == (rhs.kind, rhs.index)
    }
}

/// Equality on `(kind, index)` is reflexive, so `Eq` holds.
impl Eq for Device {}
/// Renders the backend as a lowercase name: `cpu`, `cuda`, or `metal`.
impl Display for DeviceKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            DeviceKind::Cpu => "cpu",
            #[cfg(feature = "cuda")]
            DeviceKind::Cuda => "cuda",
            #[cfg(feature = "metal")]
            DeviceKind::Metal => "metal",
        };
        f.write_str(name)
    }
}
/// Renders `cpu` for the host device, and `<kind>(<index>)` for accelerators
/// (e.g. `cuda(0)`).
impl Display for Device {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let kind = self.kind;
        match kind {
            DeviceKind::Cpu => write!(f, "{}", kind),
            #[cfg(feature = "cuda")]
            DeviceKind::Cuda => write!(f, "{}({})", kind, self.index),
            #[cfg(feature = "metal")]
            DeviceKind::Metal => write!(f, "{}({})", kind, self.index),
        }
    }
}