use std::fmt;
#[cfg(any(feature = "wgpu", feature = "cuda", feature = "rocm"))]
use cubecl::Runtime;
#[cfg(any(feature = "wgpu", feature = "cuda", feature = "rocm"))]
use cubecl::prelude::ComputeClient;
use ferrotorch_core::{FerrotorchError, FerrotorchResult};
#[cfg(feature = "cuda")]
use cubecl_cuda::{CudaDevice, CudaRuntime};
#[cfg(feature = "rocm")]
use cubecl_hip::{AmdDevice, HipRuntime};
#[cfg(feature = "wgpu")]
use cubecl_wgpu::{WgpuDevice, WgpuRuntime};
/// Identifies a compute device as a backend (the variant) plus a `usize`
/// device ordinal (the payload).
///
/// All three variants are always present, independent of which backend cargo
/// features are enabled; the value is `Copy` and usable as a hash-map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CubeDevice {
/// Device on the CUDA backend; the payload is its ordinal.
Cuda(usize),
/// Device on the WGPU backend; the payload is its ordinal.
Wgpu(usize),
/// Device on the ROCm backend; the payload is its ordinal.
Rocm(usize),
}
impl CubeDevice {
    /// Returns the device ordinal carried by this value, whatever the backend.
    #[inline]
    pub fn ordinal(&self) -> usize {
        // Every variant wraps exactly one `usize`, so a single irrefutable
        // or-pattern is enough to pull it out.
        let (Self::Cuda(index) | Self::Wgpu(index) | Self::Rocm(index)) = *self;
        index
    }

    /// Returns the lowercase backend identifier: `"cuda"`, `"wgpu"` or `"rocm"`.
    #[inline]
    pub fn backend_name(&self) -> &'static str {
        match *self {
            Self::Rocm(_) => "rocm",
            Self::Wgpu(_) => "wgpu",
            Self::Cuda(_) => "cuda",
        }
    }
}
impl fmt::Display for CubeDevice {
    /// Renders the device as `<backend>:<ordinal>`, e.g. `cuda:2`.
    ///
    /// Width, fill and alignment flags on the formatter are not applied to
    /// the output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let backend = self.backend_name();
        let ordinal = self.ordinal();
        f.write_fmt(format_args!("{}:{}", backend, ordinal))
    }
}
/// Backend-specific cubecl `ComputeClient`, wrapped in one enum.
///
/// Each variant is compiled in only when the cargo feature of the same name
/// (`wgpu`, `cuda`, `rocm`) is enabled. With none of them enabled the enum has
/// no variants and therefore no values.
#[derive(Clone)]
pub enum CubeClient {
/// Client for the WGPU runtime (feature `wgpu`).
#[cfg(feature = "wgpu")]
Wgpu(ComputeClient<WgpuRuntime>),
/// Client for the CUDA runtime (feature `cuda`).
#[cfg(feature = "cuda")]
Cuda(ComputeClient<CudaRuntime>),
/// Client for the HIP runtime used by the ROCm backend (feature `rocm`).
#[cfg(feature = "rocm")]
Rocm(ComputeClient<HipRuntime>),
}
impl fmt::Debug for CubeClient {
    /// Writes only the variant name with an elided payload, e.g.
    /// `CubeClient::Wgpu(..)`; the wrapped client itself is never formatted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            #[cfg(feature = "wgpu")]
            Self::Wgpu(_) => "CubeClient::Wgpu(..)",
            #[cfg(feature = "cuda")]
            Self::Cuda(_) => "CubeClient::Cuda(..)",
            #[cfg(feature = "rocm")]
            Self::Rocm(_) => "CubeClient::Rocm(..)",
            // Without any backend feature the enum has no variants; this
            // wildcard arm keeps the match over `&CubeClient` well-formed.
            #[cfg(not(any(feature = "wgpu", feature = "cuda", feature = "rocm")))]
            _ => "CubeClient::<no backend>",
        };
        f.write_str(label)
    }
}
// Compiled only when no backend feature is enabled. The closure is never
// called; it merely names `CubeClient` through `size_of`.
// NOTE(review): presumably this exists so the variant-less `CubeClient` is
// still referenced and type-checked in featureless builds — confirm intent.
#[cfg(not(any(feature = "wgpu", feature = "cuda", feature = "rocm")))]
const _: fn() = || {
let _ = std::mem::size_of::<CubeClient>();
};
/// A device descriptor paired with the backend client created for it.
///
/// Values are built through `CubeRuntime::new` or `CubeRuntime::auto`; both
/// fields are private and exposed read-only via `device()` and `client()`.
#[derive(Clone, Debug)]
pub struct CubeRuntime {
// The device this runtime was created for.
device: CubeDevice,
// The backend client produced for `device` by `make_client`.
client: CubeClient,
}
impl CubeRuntime {
    /// Creates a runtime for `device` by building the matching backend client.
    ///
    /// # Errors
    ///
    /// Returns [`FerrotorchError::DeviceUnavailable`] when the crate was built
    /// without the cargo feature for the requested backend.
    ///
    /// NOTE(review): when the feature is enabled, the client returned by
    /// cubecl's `Runtime::client` is always wrapped in `Ok`; how cubecl reacts
    /// to a missing physical device is not visible here — confirm.
    pub fn new(device: CubeDevice) -> FerrotorchResult<Self> {
        let client = Self::make_client(device)?;
        Ok(Self { device, client })
    }

    /// Returns the device this runtime was created for.
    #[inline]
    pub fn device(&self) -> &CubeDevice {
        &self.device
    }

    /// Returns the backend client owned by this runtime.
    #[inline]
    pub fn client(&self) -> &CubeClient {
        &self.client
    }

    /// Picks a backend automatically, always at ordinal `0`.
    ///
    /// Compiled-in backends are tried in priority order CUDA, then ROCm, then
    /// WGPU. The first one whose initialization succeeds is returned; a failed
    /// attempt falls through to the next backend instead of ending the search.
    ///
    /// Returns `None` when no backend feature is enabled or when every
    /// compiled-in backend reported an error.
    pub fn auto() -> Option<Self> {
        // Each attempt lives in its own cfg-gated block (attributes are not
        // accepted directly on `if` expressions). Only a *successful* attempt
        // returns, so the later blocks stay reachable when several backend
        // features are enabled at once.
        #[cfg(feature = "cuda")]
        {
            if let Ok(runtime) = Self::new(CubeDevice::Cuda(0)) {
                return Some(runtime);
            }
        }
        #[cfg(feature = "rocm")]
        {
            if let Ok(runtime) = Self::new(CubeDevice::Rocm(0)) {
                return Some(runtime);
            }
        }
        #[cfg(feature = "wgpu")]
        {
            if let Ok(runtime) = Self::new(CubeDevice::Wgpu(0)) {
                return Some(runtime);
            }
        }
        None
    }

    /// Reports whether at least one backend feature (`cuda`, `rocm`, `wgpu`)
    /// was enabled at compile time.
    ///
    /// This is a build-configuration check only; no hardware is probed.
    pub fn is_available() -> bool {
        cfg!(any(feature = "cuda", feature = "rocm", feature = "wgpu"))
    }

    /// Builds the backend client for `device`.
    ///
    /// Every arm has two mutually exclusive cfg-gated bodies: the real client
    /// construction when the backend feature is enabled, and a
    /// `DeviceUnavailable` error otherwise. Exactly one of them survives
    /// compilation and becomes the arm's value.
    fn make_client(device: CubeDevice) -> FerrotorchResult<CubeClient> {
        match device {
            CubeDevice::Wgpu(idx) => {
                #[cfg(feature = "wgpu")]
                {
                    let wgpu_device = wgpu_device_for_index(idx);
                    let client = WgpuRuntime::client(&wgpu_device);
                    Ok(CubeClient::Wgpu(client))
                }
                #[cfg(not(feature = "wgpu"))]
                {
                    // Keep `idx` "used" so featureless builds stay warning-free.
                    let _ = idx;
                    Err(FerrotorchError::DeviceUnavailable)
                }
            }
            CubeDevice::Cuda(idx) => {
                #[cfg(feature = "cuda")]
                {
                    let cuda_device = CudaDevice { index: idx };
                    let client = CudaRuntime::client(&cuda_device);
                    Ok(CubeClient::Cuda(client))
                }
                #[cfg(not(feature = "cuda"))]
                {
                    let _ = idx;
                    Err(FerrotorchError::DeviceUnavailable)
                }
            }
            CubeDevice::Rocm(idx) => {
                #[cfg(feature = "rocm")]
                {
                    let amd_device = AmdDevice { index: idx };
                    let client = HipRuntime::client(&amd_device);
                    Ok(CubeClient::Rocm(client))
                }
                #[cfg(not(feature = "rocm"))]
                {
                    let _ = idx;
                    Err(FerrotorchError::DeviceUnavailable)
                }
            }
        }
    }
}
/// Translates a `CubeDevice::Wgpu` ordinal into the `WgpuDevice` that
/// `make_client` passes to `WgpuRuntime::client`.
///
/// Ordinal `0` selects `WgpuDevice::DefaultDevice`; any other ordinal `n` is
/// forwarded unchanged as `WgpuDevice::DiscreteGpu(n)`.
#[cfg(feature = "wgpu")]
fn wgpu_device_for_index(index: usize) -> WgpuDevice {
    // NOTE(review): non-zero ordinals are passed through as-is, so
    // `DiscreteGpu(0)` can never be requested through this mapping — confirm
    // this is the intended ordinal convention rather than an off-by-one.
    if index == 0 {
        WgpuDevice::DefaultDevice
    } else {
        WgpuDevice::DiscreteGpu(index)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // `ordinal` yields the wrapped index for each backend variant.
    #[test]
    fn cube_device_ordinal() {
        let cases = [
            (CubeDevice::Cuda(3), 3),
            (CubeDevice::Wgpu(1), 1),
            (CubeDevice::Rocm(0), 0),
        ];
        for (device, expected) in cases {
            assert_eq!(device.ordinal(), expected);
        }
    }

    // `backend_name` yields the lowercase backend identifier.
    #[test]
    fn cube_device_backend_name() {
        let cases = [
            (CubeDevice::Cuda(0), "cuda"),
            (CubeDevice::Wgpu(0), "wgpu"),
            (CubeDevice::Rocm(0), "rocm"),
        ];
        for (device, expected) in cases {
            assert_eq!(device.backend_name(), expected);
        }
    }

    // `Display` renders `<backend>:<ordinal>`.
    #[test]
    fn cube_device_display() {
        let cases = [
            (CubeDevice::Cuda(2), "cuda:2"),
            (CubeDevice::Wgpu(0), "wgpu:0"),
            (CubeDevice::Rocm(1), "rocm:1"),
        ];
        for (device, expected) in cases {
            assert_eq!(device.to_string(), expected);
        }
    }

    // Equality distinguishes both the ordinal and the backend.
    #[test]
    fn cube_device_equality() {
        let cuda_zero = CubeDevice::Cuda(0);
        assert_eq!(cuda_zero, CubeDevice::Cuda(0));
        assert_ne!(cuda_zero, CubeDevice::Cuda(1));
        assert_ne!(cuda_zero, CubeDevice::Wgpu(0));
    }

    // Devices work as hash-set keys, and duplicates collapse.
    #[test]
    fn cube_device_clone_and_hash() {
        use std::collections::HashSet;

        let mut seen = HashSet::new();
        seen.extend([
            CubeDevice::Cuda(0),
            CubeDevice::Wgpu(0),
            CubeDevice::Rocm(0),
        ]);
        assert_eq!(seen.len(), 3);

        // Re-inserting an existing device must not grow the set.
        seen.insert(CubeDevice::Cuda(0));
        assert_eq!(seen.len(), 3);
    }

    // With the `wgpu` feature, a runtime for ordinal 0 initializes and
    // reports the device and client variant it was built with.
    #[cfg(feature = "wgpu")]
    #[test]
    fn wgpu_runtime_new_and_device() {
        let runtime = CubeRuntime::new(CubeDevice::Wgpu(0)).expect("wgpu runtime init");
        assert_eq!(runtime.device(), &CubeDevice::Wgpu(0));
        assert!(matches!(runtime.client(), CubeClient::Wgpu(_)));
    }

    // Without any backend feature, construction fails with `DeviceUnavailable`.
    #[cfg(not(any(feature = "wgpu", feature = "cuda", feature = "rocm")))]
    #[test]
    fn no_backend_feature_yields_device_unavailable() {
        let outcome = CubeRuntime::new(CubeDevice::Wgpu(0));
        assert!(matches!(outcome, Err(FerrotorchError::DeviceUnavailable)));
    }

    // `auto` produces a runtime exactly when a backend is compiled in.
    #[test]
    fn cube_runtime_auto_returns_something_or_none() {
        let picked = CubeRuntime::auto();
        match CubeRuntime::is_available() {
            true => assert!(picked.is_some()),
            false => assert!(picked.is_none()),
        }
    }

    // `is_available` and `auto` agree with each other.
    #[test]
    fn cube_runtime_is_available_consistent() {
        assert_eq!(CubeRuntime::is_available(), CubeRuntime::auto().is_some());
    }
}