use crate::backend::{Backend, DeviceCapabilities};
use crate::error::AphelionResult;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
/// Compute device targeted by the CubeCL backend.
///
/// Every variant other than [`CubeclDevice::Cpu`] carries a device ordinal.
/// The default device is the CPU.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum CubeclDevice {
    /// Host CPU.
    #[default]
    Cpu,
    /// CUDA device with the given ordinal.
    Cuda(u32),
    /// Metal device with the given ordinal.
    Metal(u32),
    /// Vulkan device with the given ordinal.
    Vulkan(u32),
    /// WGPU device with the given ordinal.
    Wgpu(u32),
}

impl CubeclDevice {
    /// Human-readable label: `"cpu"` for the CPU, `"<family>:<ordinal>"` otherwise
    /// (e.g. `"cuda:1"`).
    pub fn as_label(&self) -> String {
        let (family, ordinal) = match self {
            Self::Cpu => return String::from("cpu"),
            Self::Cuda(ordinal) => ("cuda", ordinal),
            Self::Metal(ordinal) => ("metal", ordinal),
            Self::Vulkan(ordinal) => ("vulkan", ordinal),
            Self::Wgpu(ordinal) => ("wgpu", ordinal),
        };
        format!("{family}:{ordinal}")
    }

    /// `true` only for [`CubeclDevice::Cpu`].
    pub fn is_cpu(&self) -> bool {
        *self == Self::Cpu
    }

    /// `true` for every non-CPU variant.
    pub fn is_gpu(&self) -> bool {
        !matches!(self, Self::Cpu)
    }

    /// `true` for [`CubeclDevice::Cuda`] with any ordinal.
    pub fn is_cuda(&self) -> bool {
        matches!(self, Self::Cuda(_))
    }

    /// `true` for [`CubeclDevice::Metal`] with any ordinal.
    pub fn is_metal(&self) -> bool {
        matches!(self, Self::Metal(_))
    }

    /// `true` for [`CubeclDevice::Vulkan`] with any ordinal.
    pub fn is_vulkan(&self) -> bool {
        matches!(self, Self::Vulkan(_))
    }

    /// `true` for [`CubeclDevice::Wgpu`] with any ordinal.
    pub fn is_wgpu(&self) -> bool {
        matches!(self, Self::Wgpu(_))
    }
}
/// Configuration for a CubeCL backend: target device plus memory budget.
#[derive(Debug, Clone)]
pub struct CubeclBackendConfig {
    /// Device the backend targets.
    pub device: CubeclDevice,
    /// Fraction of device memory the backend may use. The builder
    /// [`CubeclBackendConfig::with_memory_fraction`] keeps it within `[0.0, 1.0]`.
    pub memory_fraction: f32,
}

impl Default for CubeclBackendConfig {
    /// CPU device with the default memory fraction (0.9).
    fn default() -> Self {
        Self::cpu()
    }
}

impl CubeclBackendConfig {
    /// Memory fraction used when none is specified explicitly.
    const DEFAULT_MEMORY_FRACTION: f32 = 0.9;

    /// Creates a configuration for `device` with the default memory fraction (0.9).
    pub fn new(device: CubeclDevice) -> Self {
        Self {
            device,
            memory_fraction: Self::DEFAULT_MEMORY_FRACTION,
        }
    }

    /// Configuration targeting the host CPU.
    pub fn cpu() -> Self {
        Self::new(CubeclDevice::Cpu)
    }

    /// Configuration targeting CUDA device `device_id`.
    pub fn cuda(device_id: u32) -> Self {
        Self::new(CubeclDevice::Cuda(device_id))
    }

    /// Configuration targeting Metal device `device_id`.
    pub fn metal(device_id: u32) -> Self {
        Self::new(CubeclDevice::Metal(device_id))
    }

    /// Configuration targeting Vulkan device `device_id`.
    pub fn vulkan(device_id: u32) -> Self {
        Self::new(CubeclDevice::Vulkan(device_id))
    }

    /// Configuration targeting WGPU device `device_id`.
    pub fn wgpu(device_id: u32) -> Self {
        Self::new(CubeclDevice::Wgpu(device_id))
    }

    /// Sets the memory fraction, clamped to `[0.0, 1.0]`.
    ///
    /// `f32::clamp` returns NaN unchanged. A NaN `fraction` is therefore
    /// ignored, and the previously configured value is kept rather than
    /// storing an out-of-range value.
    pub fn with_memory_fraction(mut self, fraction: f32) -> Self {
        if !fraction.is_nan() {
            self.memory_fraction = fraction.clamp(0.0, 1.0);
        }
        self
    }
}
/// CubeCL implementation of the crate's `Backend` trait.
///
/// This is currently a placeholder. It stores its configuration and an
/// initialization flag, but performs no real device setup.
#[derive(Debug)]
pub struct CubeclBackend {
    config: CubeclBackendConfig,
    /// Set by a successful `initialize`, cleared by `shutdown`.
    initialized: Arc<AtomicBool>,
}

impl Clone for CubeclBackend {
    /// Clones the configuration and copies the current initialization flag
    /// into a brand-new flag. The clone does NOT share state with `self`:
    /// shutting one down leaves the other's flag untouched.
    fn clone(&self) -> Self {
        let snapshot = self.is_initialized();
        let config = self.config.clone();
        Self {
            config,
            initialized: Arc::new(AtomicBool::new(snapshot)),
        }
    }
}

impl CubeclBackend {
    /// Builds an uninitialized backend from `config`.
    pub fn new(config: CubeclBackendConfig) -> Self {
        let initialized = Arc::new(AtomicBool::new(false));
        Self { config, initialized }
    }

    /// Uninitialized backend targeting the host CPU.
    pub fn cpu() -> Self {
        Self::new(CubeclBackendConfig::cpu())
    }

    /// Uninitialized backend targeting CUDA device `device_id`.
    pub fn cuda(device_id: u32) -> Self {
        Self::new(CubeclBackendConfig::cuda(device_id))
    }

    /// Uninitialized backend targeting Metal device `device_id`.
    pub fn metal(device_id: u32) -> Self {
        Self::new(CubeclBackendConfig::metal(device_id))
    }

    /// Uninitialized backend targeting Vulkan device `device_id`.
    pub fn vulkan(device_id: u32) -> Self {
        Self::new(CubeclBackendConfig::vulkan(device_id))
    }

    /// Uninitialized backend targeting WGPU device `device_id`.
    pub fn wgpu(device_id: u32) -> Self {
        Self::new(CubeclBackendConfig::wgpu(device_id))
    }

    /// Borrow of the configuration this backend was built with.
    pub fn config(&self) -> &CubeclBackendConfig {
        &self.config
    }

    /// Current value of the initialization flag.
    pub fn is_initialized(&self) -> bool {
        self.initialized.load(Ordering::SeqCst)
    }

    /// Static availability check; no runtime probing is performed.
    ///
    /// The CPU is always available. Metal is reported available only when
    /// compiled for macOS. CUDA, Vulkan and WGPU are always reported
    /// unavailable for now.
    fn check_device_availability(&self) -> bool {
        match self.config.device {
            CubeclDevice::Cpu => true,
            CubeclDevice::Metal(_) => cfg!(target_os = "macos"),
            CubeclDevice::Cuda(_) | CubeclDevice::Vulkan(_) | CubeclDevice::Wgpu(_) => false,
        }
    }
}

impl Default for CubeclBackend {
    /// Uninitialized CPU backend.
    fn default() -> Self {
        Self::cpu()
    }
}
impl Backend for CubeclBackend {
fn name(&self) -> &str {
"cubecl"
}
fn device(&self) -> &str {
match &self.config.device {
CubeclDevice::Cpu => "cpu",
CubeclDevice::Cuda(_) => "cuda",
CubeclDevice::Metal(_) => "metal",
CubeclDevice::Vulkan(_) => "vulkan",
CubeclDevice::Wgpu(_) => "wgpu",
}
}
fn capabilities(&self) -> DeviceCapabilities {
match &self.config.device {
CubeclDevice::Cpu => DeviceCapabilities {
supports_f16: false,
supports_bf16: false,
supports_tf32: false,
max_memory_bytes: None,
compute_units: None,
},
CubeclDevice::Cuda(_) => DeviceCapabilities {
supports_f16: true,
supports_bf16: true,
supports_tf32: true,
max_memory_bytes: Some(8 * 1024 * 1024 * 1024), compute_units: Some(128), },
CubeclDevice::Metal(_) => DeviceCapabilities {
supports_f16: true,
supports_bf16: false,
supports_tf32: false,
max_memory_bytes: None,
compute_units: None,
},
CubeclDevice::Vulkan(_) => DeviceCapabilities {
supports_f16: true,
supports_bf16: false,
supports_tf32: false,
max_memory_bytes: None,
compute_units: None,
},
CubeclDevice::Wgpu(_) => DeviceCapabilities {
supports_f16: true, supports_bf16: false,
supports_tf32: false,
max_memory_bytes: None,
compute_units: None,
},
}
}
fn is_available(&self) -> bool {
self.check_device_availability()
}
fn initialize(&mut self) -> AphelionResult<()> {
if self.initialized.load(Ordering::SeqCst) {
return Ok(());
}
if !self.check_device_availability() {
return Err(crate::error::AphelionError::backend(format!(
"CubeCL device {} is not available",
self.config.device.as_label()
)));
}
self.initialized.store(true, Ordering::SeqCst);
tracing::info!(
backend = "cubecl",
device = %self.config.device.as_label(),
memory_fraction = self.config.memory_fraction,
"CubeCL backend initialized (placeholder)"
);
Ok(())
}
fn shutdown(&mut self) -> AphelionResult<()> {
if !self.initialized.load(Ordering::SeqCst) {
return Ok(());
}
self.initialized.store(false, Ordering::SeqCst);
tracing::info!(
backend = "cubecl",
device = %self.config.device.as_label(),
"CubeCL backend shutdown (placeholder)"
);
Ok(())
}
}
#[cfg(test)]
mod tests {
    // Unit tests for the CubeCL device enum, its config builder, and the
    // placeholder backend lifecycle.
    use super::*;

    #[test]
    fn test_cubecl_device_as_label() {
        assert_eq!(CubeclDevice::Cpu.as_label(), "cpu");
        assert_eq!(CubeclDevice::Cuda(0).as_label(), "cuda:0");
        assert_eq!(CubeclDevice::Cuda(1).as_label(), "cuda:1");
        assert_eq!(CubeclDevice::Metal(0).as_label(), "metal:0");
        assert_eq!(CubeclDevice::Vulkan(2).as_label(), "vulkan:2");
        assert_eq!(CubeclDevice::Wgpu(0).as_label(), "wgpu:0");
    }

    #[test]
    fn test_cubecl_device_default() {
        let device = CubeclDevice::default();
        assert_eq!(device, CubeclDevice::Cpu);
    }

    #[test]
    fn test_cubecl_device_is_cpu_gpu() {
        assert!(CubeclDevice::Cpu.is_cpu());
        assert!(!CubeclDevice::Cpu.is_gpu());
        assert!(!CubeclDevice::Cuda(0).is_cpu());
        assert!(CubeclDevice::Cuda(0).is_gpu());
        assert!(!CubeclDevice::Metal(0).is_cpu());
        assert!(CubeclDevice::Metal(0).is_gpu());
        assert!(!CubeclDevice::Vulkan(0).is_cpu());
        assert!(CubeclDevice::Vulkan(0).is_gpu());
        assert!(!CubeclDevice::Wgpu(0).is_cpu());
        assert!(CubeclDevice::Wgpu(0).is_gpu());
    }

    #[test]
    fn test_cubecl_device_type_checks() {
        assert!(CubeclDevice::Cuda(0).is_cuda());
        assert!(!CubeclDevice::Metal(0).is_cuda());
        assert!(CubeclDevice::Metal(0).is_metal());
        assert!(!CubeclDevice::Cuda(0).is_metal());
        assert!(CubeclDevice::Vulkan(0).is_vulkan());
        assert!(!CubeclDevice::Metal(0).is_vulkan());
        assert!(CubeclDevice::Wgpu(0).is_wgpu());
        assert!(!CubeclDevice::Vulkan(0).is_wgpu());
    }

    #[test]
    fn test_cubecl_device_clone() {
        let device1 = CubeclDevice::Cuda(1);
        let device2 = device1.clone();
        assert_eq!(device1, device2);
    }

    #[test]
    fn test_cubecl_backend_config_default() {
        let config = CubeclBackendConfig::default();
        assert_eq!(config.device, CubeclDevice::Cpu);
        assert!((config.memory_fraction - 0.9).abs() < f32::EPSILON);
    }

    #[test]
    fn test_cubecl_backend_config_new() {
        let config = CubeclBackendConfig::new(CubeclDevice::Cuda(1));
        assert_eq!(config.device, CubeclDevice::Cuda(1));
        assert!((config.memory_fraction - 0.9).abs() < f32::EPSILON);
    }

    #[test]
    fn test_cubecl_backend_config_cpu() {
        let config = CubeclBackendConfig::cpu();
        assert!(config.device.is_cpu());
    }

    #[test]
    fn test_cubecl_backend_config_cuda() {
        let config = CubeclBackendConfig::cuda(0);
        assert_eq!(config.device, CubeclDevice::Cuda(0));
    }

    #[test]
    fn test_cubecl_backend_config_metal() {
        let config = CubeclBackendConfig::metal(0);
        assert_eq!(config.device, CubeclDevice::Metal(0));
    }

    #[test]
    fn test_cubecl_backend_config_vulkan() {
        let config = CubeclBackendConfig::vulkan(0);
        assert_eq!(config.device, CubeclDevice::Vulkan(0));
    }

    #[test]
    fn test_cubecl_backend_config_wgpu() {
        let config = CubeclBackendConfig::wgpu(0);
        assert_eq!(config.device, CubeclDevice::Wgpu(0));
    }

    #[test]
    fn test_cubecl_backend_config_with_memory_fraction() {
        let config = CubeclBackendConfig::cuda(0).with_memory_fraction(0.5);
        assert!((config.memory_fraction - 0.5).abs() < f32::EPSILON);
    }

    #[test]
    fn test_cubecl_backend_config_memory_fraction_clamped() {
        let config_low = CubeclBackendConfig::cuda(0).with_memory_fraction(-0.5);
        assert!((config_low.memory_fraction - 0.0).abs() < f32::EPSILON);
        let config_high = CubeclBackendConfig::cuda(0).with_memory_fraction(1.5);
        assert!((config_high.memory_fraction - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_cubecl_backend_config_builder_chain() {
        let config = CubeclBackendConfig::cuda(0).with_memory_fraction(0.7);
        assert_eq!(config.device, CubeclDevice::Cuda(0));
        assert!((config.memory_fraction - 0.7).abs() < f32::EPSILON);
    }

    #[test]
    fn test_cubecl_backend_new() {
        let backend = CubeclBackend::new(CubeclBackendConfig::default());
        assert_eq!(backend.name(), "cubecl");
        assert_eq!(backend.device(), "cpu");
        assert!(!backend.is_initialized());
    }

    #[test]
    fn test_cubecl_backend_cpu() {
        let backend = CubeclBackend::cpu();
        assert_eq!(backend.device(), "cpu");
        assert!(backend.config().device.is_cpu());
    }

    #[test]
    fn test_cubecl_backend_cuda() {
        let backend = CubeclBackend::cuda(0);
        assert_eq!(backend.device(), "cuda");
        assert!(backend.config().device.is_gpu());
    }

    #[test]
    fn test_cubecl_backend_metal() {
        let backend = CubeclBackend::metal(0);
        assert_eq!(backend.device(), "metal");
        assert!(backend.config().device.is_metal());
    }

    #[test]
    fn test_cubecl_backend_vulkan() {
        let backend = CubeclBackend::vulkan(0);
        assert_eq!(backend.device(), "vulkan");
        assert!(backend.config().device.is_vulkan());
    }

    #[test]
    fn test_cubecl_backend_wgpu() {
        let backend = CubeclBackend::wgpu(0);
        assert_eq!(backend.device(), "wgpu");
        assert!(backend.config().device.is_wgpu());
    }

    #[test]
    fn test_cubecl_backend_default() {
        let backend = CubeclBackend::default();
        assert_eq!(backend.device(), "cpu");
    }

    #[test]
    fn test_cubecl_backend_capabilities_cpu() {
        let backend = CubeclBackend::cpu();
        let caps = backend.capabilities();
        assert!(!caps.supports_f16);
        assert!(!caps.supports_bf16);
        assert!(!caps.supports_tf32);
        assert!(caps.max_memory_bytes.is_none());
    }

    #[test]
    fn test_cubecl_backend_capabilities_cuda() {
        let backend = CubeclBackend::cuda(0);
        let caps = backend.capabilities();
        assert!(caps.supports_f16);
        assert!(caps.supports_bf16);
        assert!(caps.supports_tf32);
        assert!(caps.max_memory_bytes.is_some());
        assert!(caps.compute_units.is_some());
    }

    #[test]
    fn test_cubecl_backend_capabilities_metal() {
        let backend = CubeclBackend::metal(0);
        let caps = backend.capabilities();
        assert!(caps.supports_f16);
        assert!(!caps.supports_bf16);
        assert!(!caps.supports_tf32);
    }

    #[test]
    fn test_cubecl_backend_capabilities_vulkan() {
        let backend = CubeclBackend::vulkan(0);
        let caps = backend.capabilities();
        assert!(caps.supports_f16);
        assert!(!caps.supports_bf16);
        assert!(!caps.supports_tf32);
        assert!(caps.max_memory_bytes.is_none());
        assert!(caps.compute_units.is_none());
    }

    #[test]
    fn test_cubecl_backend_capabilities_wgpu() {
        let backend = CubeclBackend::wgpu(0);
        let caps = backend.capabilities();
        assert!(caps.supports_f16);
        assert!(!caps.supports_bf16);
        assert!(!caps.supports_tf32);
    }

    #[test]
    fn test_cubecl_backend_is_available_cpu() {
        let backend = CubeclBackend::cpu();
        assert!(backend.is_available());
    }

    #[test]
    fn test_cubecl_backend_is_available_cuda() {
        let backend = CubeclBackend::cuda(0);
        assert!(!backend.is_available());
    }

    #[test]
    fn test_cubecl_backend_is_available_metal_matches_platform() {
        // Metal availability is decided at compile time by the target OS.
        let backend = CubeclBackend::metal(0);
        assert_eq!(backend.is_available(), cfg!(target_os = "macos"));
    }

    #[test]
    fn test_cubecl_backend_is_available_vulkan() {
        let backend = CubeclBackend::vulkan(0);
        assert!(!backend.is_available());
    }

    #[test]
    fn test_cubecl_backend_is_available_wgpu() {
        let backend = CubeclBackend::wgpu(0);
        assert!(!backend.is_available());
    }

    #[test]
    fn test_cubecl_backend_initialize_cpu() {
        let mut backend = CubeclBackend::cpu();
        assert!(!backend.is_initialized());
        let result = backend.initialize();
        assert!(result.is_ok());
        assert!(backend.is_initialized());
        let result = backend.initialize();
        assert!(result.is_ok());
    }

    #[test]
    fn test_cubecl_backend_initialize_cuda_fails() {
        let mut backend = CubeclBackend::cuda(0);
        let result = backend.initialize();
        assert!(result.is_err());
    }

    #[test]
    fn test_cubecl_backend_initialize_metal_matches_platform() {
        let mut backend = CubeclBackend::metal(0);
        let on_macos = cfg!(target_os = "macos");
        assert_eq!(backend.initialize().is_ok(), on_macos);
        assert_eq!(backend.is_initialized(), on_macos);
    }

    #[test]
    fn test_cubecl_backend_initialize_vulkan_fails() {
        let mut backend = CubeclBackend::vulkan(0);
        let result = backend.initialize();
        assert!(result.is_err());
    }

    #[test]
    fn test_cubecl_backend_initialize_wgpu_fails() {
        let mut backend = CubeclBackend::wgpu(0);
        let result = backend.initialize();
        assert!(result.is_err());
    }

    #[test]
    fn test_cubecl_backend_failed_initialize_stays_uninitialized() {
        let mut backend = CubeclBackend::cuda(0);
        assert!(backend.initialize().is_err());
        assert!(!backend.is_initialized());
    }

    #[test]
    fn test_cubecl_backend_shutdown() {
        let mut backend = CubeclBackend::cpu();
        backend.initialize().unwrap();
        assert!(backend.is_initialized());
        let result = backend.shutdown();
        assert!(result.is_ok());
        assert!(!backend.is_initialized());
        let result = backend.shutdown();
        assert!(result.is_ok());
    }

    #[test]
    fn test_cubecl_backend_shutdown_without_init() {
        let mut backend = CubeclBackend::cpu();
        assert!(!backend.is_initialized());
        let result = backend.shutdown();
        assert!(result.is_ok());
    }

    #[test]
    fn test_cubecl_backend_clone() {
        let mut backend = CubeclBackend::cpu();
        backend.initialize().unwrap();
        let cloned = backend.clone();
        assert!(cloned.is_initialized());
        assert_eq!(cloned.device(), backend.device());
    }

    #[test]
    fn test_cubecl_backend_clone_has_independent_state() {
        // `Clone` snapshots the flag into a fresh `Arc`, so the two backends
        // must not observe each other's lifecycle changes.
        let mut backend = CubeclBackend::cpu();
        backend.initialize().unwrap();
        let cloned = backend.clone();
        backend.shutdown().unwrap();
        assert!(!backend.is_initialized());
        assert!(cloned.is_initialized());
    }

    #[test]
    fn test_cubecl_backend_clone_preserves_config() {
        let backend = CubeclBackend::new(CubeclBackendConfig::cuda(1).with_memory_fraction(0.5));
        let cloned = backend.clone();
        assert_eq!(cloned.config().device, CubeclDevice::Cuda(1));
        assert!((cloned.config().memory_fraction - 0.5).abs() < f32::EPSILON);
    }

    #[test]
    fn test_cubecl_backend_lifecycle() {
        let mut backend = CubeclBackend::cpu();
        assert!(!backend.is_initialized());
        assert!(backend.is_available());
        let init_result = backend.initialize();
        assert!(init_result.is_ok());
        assert!(backend.is_initialized());
        let shutdown_result = backend.shutdown();
        assert!(shutdown_result.is_ok());
        assert!(!backend.is_initialized());
        let reinit_result = backend.initialize();
        assert!(reinit_result.is_ok());
        assert!(backend.is_initialized());
    }

    #[test]
    fn test_cubecl_backend_config_access() {
        let backend = CubeclBackend::new(CubeclBackendConfig::cuda(2).with_memory_fraction(0.75));
        let config = backend.config();
        assert_eq!(config.device, CubeclDevice::Cuda(2));
        assert!((config.memory_fraction - 0.75).abs() < f32::EPSILON);
    }
}