use std::fmt;
/// A compute device, identified by its backend and a per-backend ordinal.
///
/// The `usize` payload of every variant is the value returned by
/// [`CubeDevice::ordinal`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CubeDevice {
    /// A CUDA device with the given ordinal.
    Cuda(usize),
    /// A WGPU device with the given ordinal.
    Wgpu(usize),
    /// A ROCm device with the given ordinal.
    Rocm(usize),
}

impl CubeDevice {
    /// Returns the ordinal stored in this device, whatever its backend.
    #[inline]
    pub fn ordinal(&self) -> usize {
        match *self {
            CubeDevice::Cuda(index) => index,
            CubeDevice::Wgpu(index) => index,
            CubeDevice::Rocm(index) => index,
        }
    }

    /// Returns the lowercase backend identifier: `"cuda"`, `"wgpu"` or `"rocm"`.
    #[inline]
    pub fn backend_name(&self) -> &'static str {
        match *self {
            CubeDevice::Cuda(..) => "cuda",
            CubeDevice::Wgpu(..) => "wgpu",
            CubeDevice::Rocm(..) => "rocm",
        }
    }
}

impl fmt::Display for CubeDevice {
    // Renders as `<backend>:<ordinal>`, e.g. `cuda:2`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let backend = self.backend_name();
        let ordinal = self.ordinal();
        write!(f, "{}:{}", backend, ordinal)
    }
}
#[derive(Debug, Clone)]
pub struct CubeRuntime {
device: CubeDevice,
}
impl CubeRuntime {
pub fn new(device: CubeDevice) -> Self {
Self { device }
}
#[inline]
pub fn device(&self) -> &CubeDevice {
&self.device
}
pub fn auto() -> Option<Self> {
#[cfg(feature = "cuda")]
{
return Some(Self {
device: CubeDevice::Cuda(0),
});
}
#[cfg(feature = "rocm")]
{
return Some(Self {
device: CubeDevice::Rocm(0),
});
}
#[cfg(feature = "wgpu")]
{
return Some(Self {
device: CubeDevice::Wgpu(0),
});
}
#[allow(unreachable_code)]
None
}
pub fn is_available() -> bool {
cfg!(any(feature = "cuda", feature = "rocm", feature = "wgpu"))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cube_device_ordinal() {
        assert_eq!(CubeDevice::Cuda(3).ordinal(), 3);
        assert_eq!(CubeDevice::Wgpu(1).ordinal(), 1);
        assert_eq!(CubeDevice::Rocm(0).ordinal(), 0);
    }

    #[test]
    fn cube_device_backend_name() {
        assert_eq!(CubeDevice::Cuda(0).backend_name(), "cuda");
        assert_eq!(CubeDevice::Wgpu(0).backend_name(), "wgpu");
        assert_eq!(CubeDevice::Rocm(0).backend_name(), "rocm");
    }

    #[test]
    fn cube_device_display() {
        assert_eq!(CubeDevice::Cuda(2).to_string(), "cuda:2");
        assert_eq!(CubeDevice::Wgpu(0).to_string(), "wgpu:0");
        assert_eq!(CubeDevice::Rocm(1).to_string(), "rocm:1");
    }

    #[test]
    fn cube_device_equality() {
        assert_eq!(CubeDevice::Cuda(0), CubeDevice::Cuda(0));
        assert_ne!(CubeDevice::Cuda(0), CubeDevice::Cuda(1));
        assert_ne!(CubeDevice::Cuda(0), CubeDevice::Wgpu(0));
    }

    #[test]
    fn cube_runtime_new_and_device() {
        let rt = CubeRuntime::new(CubeDevice::Wgpu(0));
        assert_eq!(*rt.device(), CubeDevice::Wgpu(0));
    }

    /// `auto()` returns `Some` exactly when `is_available()` reports a backend.
    #[test]
    fn cube_runtime_is_available_consistent() {
        assert_eq!(CubeRuntime::is_available(), CubeRuntime::auto().is_some());
    }

    /// `auto()` must honour the CUDA > ROCm > WGPU priority and pick ordinal 0.
    #[test]
    fn cube_runtime_auto_follows_backend_priority() {
        let expected = if cfg!(feature = "cuda") {
            Some(CubeDevice::Cuda(0))
        } else if cfg!(feature = "rocm") {
            Some(CubeDevice::Rocm(0))
        } else if cfg!(feature = "wgpu") {
            Some(CubeDevice::Wgpu(0))
        } else {
            None
        };
        assert_eq!(CubeRuntime::auto().map(|rt| *rt.device()), expected);
    }

    #[test]
    fn cube_device_copy_and_hash() {
        use std::collections::HashSet;

        // `CubeDevice` is `Copy`: the original stays usable after being moved.
        let original = CubeDevice::Rocm(5);
        let copied = original;
        assert_eq!(original, copied);

        let mut set = HashSet::new();
        set.insert(CubeDevice::Cuda(0));
        set.insert(CubeDevice::Wgpu(0));
        set.insert(CubeDevice::Rocm(0));
        assert_eq!(set.len(), 3);
        // Re-inserting an equal device must not grow the set.
        set.insert(CubeDevice::Cuda(0));
        assert_eq!(set.len(), 3);
    }
}