use core::fmt;
use sysinfo::System;
/// A compute device the crate can target.
///
/// `Cpu` is always compiled in and is the [`Default`]. Every other variant
/// exists only when its Cargo feature is enabled and carries the
/// backend-specific device ordinal.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum Device {
    /// The host CPU.
    #[default]
    Cpu,
    /// A CUDA device with the given ordinal (requires the `cuda` feature).
    #[cfg(feature = "cuda")]
    Cuda(usize),
    /// A Vulkan device with the given ordinal (requires the `vulkan` feature).
    #[cfg(feature = "vulkan")]
    Vulkan(usize),
    /// A Metal device with the given ordinal (requires the `metal` feature).
    #[cfg(feature = "metal")]
    Metal(usize),
    /// A wgpu device with the given ordinal (requires the `wgpu` feature).
    #[cfg(feature = "wgpu")]
    Wgpu(usize),
}
impl Device {
    /// Returns `true` if this device can currently be used.
    ///
    /// The CPU is always available. Every accelerator variant defers to its
    /// backend's `is_device_available` check for the stored ordinal.
    #[must_use]
    pub fn is_available(self) -> bool {
        match self {
            Self::Cpu => true,
            #[cfg(feature = "cuda")]
            Self::Cuda(idx) => crate::backends::cuda::is_device_available(idx),
            #[cfg(feature = "vulkan")]
            Self::Vulkan(idx) => crate::backends::vulkan::is_device_available(idx),
            #[cfg(feature = "metal")]
            Self::Metal(idx) => crate::backends::metal::is_device_available(idx),
            #[cfg(feature = "wgpu")]
            Self::Wgpu(idx) => crate::backends::wgpu_backend::is_device_available(idx),
        }
    }

    /// Returns `true` if this is the host CPU.
    #[must_use]
    pub const fn is_cpu(self) -> bool {
        matches!(self, Self::Cpu)
    }

    /// Returns `true` for every non-CPU (accelerator) variant.
    #[must_use]
    pub const fn is_gpu(self) -> bool {
        !self.is_cpu()
    }

    /// Returns the backend-specific device ordinal.
    ///
    /// The CPU has no ordinal and always reports `0`.
    #[must_use]
    pub const fn index(self) -> usize {
        match self {
            Self::Cpu => 0,
            #[cfg(feature = "cuda")]
            Self::Cuda(idx) => idx,
            #[cfg(feature = "vulkan")]
            Self::Vulkan(idx) => idx,
            #[cfg(feature = "metal")]
            Self::Metal(idx) => idx,
            #[cfg(feature = "wgpu")]
            Self::Wgpu(idx) => idx,
        }
    }

    /// Returns the lowercase backend name: `"cpu"`, `"cuda"`, `"vulkan"`,
    /// `"metal"` or `"wgpu"`.
    #[must_use]
    pub const fn device_type(self) -> &'static str {
        match self {
            Self::Cpu => "cpu",
            #[cfg(feature = "cuda")]
            Self::Cuda(_) => "cuda",
            #[cfg(feature = "vulkan")]
            Self::Vulkan(_) => "vulkan",
            #[cfg(feature = "metal")]
            Self::Metal(_) => "metal",
            #[cfg(feature = "wgpu")]
            Self::Wgpu(_) => "wgpu",
        }
    }

    /// Constructs the host CPU device (equal to `Device::default()`).
    #[must_use]
    pub const fn cpu() -> Self {
        Self::Cpu
    }

    /// Constructs a CUDA device with the given ordinal.
    #[cfg(feature = "cuda")]
    #[must_use]
    pub const fn cuda(index: usize) -> Self {
        Self::Cuda(index)
    }

    /// Constructs a Vulkan device with the given ordinal.
    #[cfg(feature = "vulkan")]
    #[must_use]
    pub const fn vulkan(index: usize) -> Self {
        Self::Vulkan(index)
    }

    /// Constructs a Metal device with the given ordinal.
    #[cfg(feature = "metal")]
    #[must_use]
    pub const fn metal(index: usize) -> Self {
        Self::Metal(index)
    }

    /// Constructs a wgpu device with the given ordinal.
    #[cfg(feature = "wgpu")]
    #[must_use]
    pub const fn wgpu(index: usize) -> Self {
        Self::Wgpu(index)
    }
}
/// Renders the host as `cpu` and an accelerator as `<backend>:<ordinal>`,
/// for example `cuda:0`.
impl fmt::Display for Device {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let device = *self;
        if device.is_cpu() {
            return f.write_str("cpu");
        }
        // Each accelerator's textual form is its backend name and ordinal.
        write!(f, "{}:{}", device.device_type(), device.index())
    }
}
/// A snapshot of what a [`Device`] reports about itself.
///
/// Produced by `Device::capabilities()`. For the CPU it is filled in
/// locally; accelerator backends supply their own values.
#[derive(Clone, Debug)]
pub struct DeviceCapabilities {
    /// Human-readable device name (`"CPU"` for the host).
    pub name: String,
    /// Total memory reported for the device. For the CPU this comes from
    /// `sysinfo` (bytes in recent versions; backend units should be confirmed).
    pub total_memory: usize,
    /// Currently available memory, in the same unit as `total_memory`.
    pub available_memory: usize,
    /// Whether half-precision (`f16`) arithmetic is supported.
    pub supports_f16: bool,
    /// Whether double-precision (`f64`) arithmetic is supported.
    pub supports_f64: bool,
    /// Maximum threads per block. The CPU reports its available parallelism here.
    pub max_threads_per_block: usize,
    /// Optional `(major, minor)` compute capability; `None` for the CPU.
    pub compute_capability: Option<(usize, usize)>,
}
impl Device {
    /// Reports what this device supports.
    ///
    /// For the CPU the record is assembled locally. Memory figures are read
    /// on every call, and `max_threads_per_block` is the host's available
    /// parallelism. Every accelerator variant delegates to its backend's
    /// `get_capabilities` for the stored ordinal.
    #[must_use]
    pub fn capabilities(self) -> DeviceCapabilities {
        match self {
            Self::Cpu => Self::cpu_capabilities(),
            #[cfg(feature = "cuda")]
            Self::Cuda(idx) => crate::backends::cuda::get_capabilities(idx),
            #[cfg(feature = "vulkan")]
            Self::Vulkan(idx) => crate::backends::vulkan::get_capabilities(idx),
            #[cfg(feature = "metal")]
            Self::Metal(idx) => crate::backends::metal::get_capabilities(idx),
            #[cfg(feature = "wgpu")]
            Self::Wgpu(idx) => crate::backends::wgpu_backend::get_capabilities(idx),
        }
    }

    /// Builds the capability record for the host CPU.
    fn cpu_capabilities() -> DeviceCapabilities {
        let total_memory = get_system_memory();
        let available_memory = get_available_memory();
        DeviceCapabilities {
            name: String::from("CPU"),
            total_memory,
            available_memory,
            supports_f16: true,
            supports_f64: true,
            max_threads_per_block: num_cpus(),
            compute_capability: None,
        }
    }
}
/// Returns the host's total physical memory as reported by `sysinfo`.
///
/// Only memory statistics are refreshed. `System::new_all()` would also
/// enumerate every process and refresh CPU data, which is far more expensive
/// and irrelevant here. If the `u64` value does not fit in `usize` (32-bit
/// targets), it saturates at `usize::MAX` instead of silently truncating.
fn get_system_memory() -> usize {
    let mut sys = System::new();
    sys.refresh_memory();
    usize::try_from(sys.total_memory()).unwrap_or(usize::MAX)
}
/// Returns the host's currently available memory as reported by `sysinfo`.
///
/// Only memory statistics are refreshed. `System::new_all()` would also
/// enumerate every process and refresh CPU data for no benefit. If the `u64`
/// value does not fit in `usize` (32-bit targets), it saturates at
/// `usize::MAX` instead of silently truncating.
fn get_available_memory() -> usize {
    let mut sys = System::new();
    sys.refresh_memory();
    usize::try_from(sys.available_memory()).unwrap_or(usize::MAX)
}
/// Number of threads the host can run in parallel.
///
/// Falls back to `1` when the platform cannot report it.
fn num_cpus() -> usize {
    match std::thread::available_parallelism() {
        Ok(count) => count.get(),
        Err(_) => 1,
    }
}
impl DeviceCapabilities {
    /// Returns whether single-precision (`f32`) arithmetic is supported.
    ///
    /// This is unconditionally `true` for every device; no field is consulted.
    #[must_use]
    pub const fn supports_f32(&self) -> bool {
        true
    }
}
/// Returns the number of CUDA devices reported by the CUDA backend.
#[cfg(feature = "cuda")]
#[must_use]
pub fn cuda_device_count() -> usize {
    use crate::backends::cuda::device_count;
    device_count()
}
/// Returns the number of Vulkan devices reported by the Vulkan backend.
#[cfg(feature = "vulkan")]
#[must_use]
pub fn vulkan_device_count() -> usize {
    use crate::backends::vulkan::device_count;
    device_count()
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cpu_device() {
        let device = Device::Cpu;
        assert!(device.is_cpu());
        assert!(!device.is_gpu());
        assert!(device.is_available());
        assert_eq!(device.device_type(), "cpu");
        // The CPU has no ordinal; `index()` reports 0 for it.
        assert_eq!(device.index(), 0);
    }

    #[test]
    fn test_cpu_constructor() {
        assert_eq!(Device::cpu(), Device::Cpu);
    }

    #[test]
    fn test_device_display() {
        let cpu = Device::Cpu;
        assert_eq!(format!("{cpu}"), "cpu");
        // Display and device_type must agree for the host.
        assert_eq!(cpu.to_string(), cpu.device_type());
    }

    #[test]
    fn test_device_default() {
        let device = Device::default();
        assert_eq!(device, Device::Cpu);
    }

    #[test]
    fn test_device_capabilities() {
        let caps = Device::Cpu.capabilities();
        assert_eq!(caps.name, "CPU");
        assert!(caps.supports_f32());
        assert!(caps.supports_f16);
        assert!(caps.supports_f64);
        // Derived from available parallelism, which falls back to 1.
        assert!(caps.max_threads_per_block >= 1);
        assert!(caps.compute_capability.is_none());
    }

    // Exercises only pure accessors, so it needs no CUDA hardware.
    #[cfg(feature = "cuda")]
    #[test]
    fn test_cuda_display_and_index() {
        let dev = Device::cuda(3);
        assert!(dev.is_gpu());
        assert!(!dev.is_cpu());
        assert_eq!(dev.index(), 3);
        assert_eq!(dev.device_type(), "cuda");
        assert_eq!(dev.to_string(), "cuda:3");
    }
}