use std::fmt;
/// Common interface implemented by every compute-device marker type.
///
/// Implementors must be cloneable, shareable across threads and free of
/// non-`'static` borrows, as required by the supertrait bounds.
pub trait Device: Clone + Send + Sync + 'static {
    /// Returns a short, static name identifying the kind of device.
    fn name(&self) -> &'static str;

    /// Reports whether this device is a CPU.
    fn is_cpu(&self) -> bool;

    /// Reports whether this device is a CUDA device.
    fn is_cuda(&self) -> bool;
}
/// Zero-sized marker type for the CPU device.
///
/// Being a unit struct, every `Cpu` value is equal to every other; the
/// `PartialEq`/`Eq`/`Hash` derives let it be compared and used as a key in
/// hashed collections.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct Cpu;
impl Device for Cpu {
    /// Always `true`: this type is the CPU marker.
    fn is_cpu(&self) -> bool {
        true
    }

    /// Always `false`: the exact opposite of [`Device::is_cpu`] for this type.
    fn is_cuda(&self) -> bool {
        !self.is_cpu()
    }

    /// The static device name, `"cpu"`.
    fn name(&self) -> &'static str {
        "cpu"
    }
}
impl fmt::Display for Cpu {
    /// Writes the literal text `cpu`; width, fill and precision flags on the
    /// formatter are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("cpu")
    }
}
/// Marker type for a CUDA device, distinguished by its `device_id`.
///
/// Two values compare equal exactly when their `device_id`s match, and the
/// type can be used as a key in hashed collections. `Cuda::default()` has
/// `device_id == 0`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct Cuda {
    /// Identifier of the device (presumably the CUDA device ordinal —
    /// confirm against the code that consumes it).
    pub device_id: i32,
}
impl Cuda {
pub fn new(device_id: i32) -> Self {
Cuda { device_id }
}
}
impl Device for Cuda {
    /// Always `true`: this type is the CUDA marker.
    fn is_cuda(&self) -> bool {
        true
    }

    /// Always `false`: the exact opposite of [`Device::is_cuda`] for this type.
    fn is_cpu(&self) -> bool {
        !self.is_cuda()
    }

    /// The static device name, `"cuda"` (the device id is not included).
    fn name(&self) -> &'static str {
        "cuda"
    }
}
impl fmt::Display for Cuda {
    /// Writes `cuda:` followed by the decimal `device_id`, e.g. `cuda:0`.
    ///
    /// Width, fill and precision flags on the outer formatter are not
    /// applied to the output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let id = self.device_id;
        f.write_fmt(format_args!("cuda:{}", id))
    }
}