//! Device selection and GPU acceleration utilities
//!
//! Provides automatic device detection (CUDA/Metal/CPU) and device management
//! for efficient model training and inference on GPUs.
//!
//! # Features
//!
//! - **Auto-detection**: Automatically detects available CUDA/Metal devices
//! - **Fallback**: Gracefully falls back to CPU if GPU is unavailable
//! - **Device Introspection**: [`DeviceInfo`] descriptors and [`list_devices`] (GPU name and memory reporting is not populated yet)
//! - **Multi-GPU**: Support for selecting specific GPU devices
//!
//! # Examples
//!
//! ```rust
//! use kizzasi_core::device::{DeviceConfig, DeviceType, get_best_device};
//!
//! // Auto-select best available device
//! let device = get_best_device();
//!
//! // Or configure manually
//! let config = DeviceConfig::default()
//! .with_device_type(DeviceType::Cpu)
//! .with_device_id(0);
//! let device = config.create_device()?;
//! # Ok::<(), Box<dyn std::error::Error>>(())
//! ```
#[cfg(any(feature = "cuda", feature = "metal"))]
use crate::error::CoreError;
use crate::error::CoreResult;
use candle_core::Device;
use serde::{Deserialize, Serialize};
use std::fmt;
/// Kind of compute backend a model can run on.
///
/// The set of variants depends on the build: `Cpu` always exists, `Cuda`
/// exists only with the `cuda` feature on Linux or Windows, and `Metal`
/// exists only with the `metal` feature.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum DeviceType {
    /// Host CPU; always available.
    Cpu,
    /// NVIDIA GPU via CUDA (only in Linux/Windows builds with `cuda`).
    #[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
    Cuda,
    /// Apple GPU via Metal (only in builds with `metal`).
    #[cfg(feature = "metal")]
    Metal,
}
impl fmt::Display for DeviceType {
    /// Writes the short backend label: `CPU`, `CUDA` or `Metal`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match *self {
            DeviceType::Cpu => "CPU",
            #[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
            DeviceType::Cuda => "CUDA",
            #[cfg(feature = "metal")]
            DeviceType::Metal => "Metal",
        };
        f.write_str(label)
    }
}
/// User-facing description of which device to run on and with which
/// precision options.
///
/// Build one with [`DeviceConfig::new`] / `Default` and the `with_*`
/// builders, then call [`DeviceConfig::create_device`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct DeviceConfig {
    /// Backend to instantiate.
    pub device_type: DeviceType,
    /// Ordinal of the accelerator to open on multi-GPU machines.
    pub device_id: usize,
    /// Request mixed precision (FP16). `create_device` does not read this
    /// flag; presumably consumed elsewhere — confirm.
    pub use_fp16: bool,
    /// Request TF32 matmul (only meaningful for CUDA). `create_device` does
    /// not read this flag; presumably consumed elsewhere — confirm.
    pub use_tf32: bool,
}
impl Default for DeviceConfig {
fn default() -> Self {
Self {
device_type: DeviceType::Cpu,
device_id: 0,
use_fp16: false,
use_tf32: false,
}
}
}
impl DeviceConfig {
    /// Same as `DeviceConfig::default()`: CPU, device 0, FP16/TF32 off.
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns this configuration with the backend replaced.
    pub fn with_device_type(self, device_type: DeviceType) -> Self {
        Self {
            device_type,
            ..self
        }
    }

    /// Returns this configuration targeting accelerator ordinal `device_id`.
    pub fn with_device_id(self, device_id: usize) -> Self {
        Self { device_id, ..self }
    }

    /// Returns this configuration with FP16 requested (`true`) or not.
    pub fn with_fp16(self, enabled: bool) -> Self {
        Self {
            use_fp16: enabled,
            ..self
        }
    }

    /// Returns this configuration with TF32 requested (`true`) or not.
    /// Only meaningful for CUDA.
    pub fn with_tf32(self, enabled: bool) -> Self {
        Self {
            use_tf32: enabled,
            ..self
        }
    }

    /// Opens the candle [`Device`] described by `device_type` and `device_id`.
    ///
    /// `device_id` is ignored for the CPU. `use_fp16` / `use_tf32` are not
    /// consulted here.
    ///
    /// # Errors
    ///
    /// Returns `CoreError::DeviceError` when the CUDA or Metal backend fails
    /// to open the requested ordinal. Creating the CPU device never fails.
    pub fn create_device(&self) -> CoreResult<Device> {
        match (self.device_type, self.device_id) {
            (DeviceType::Cpu, _) => Ok(Device::Cpu),
            #[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
            (DeviceType::Cuda, ordinal) => Device::new_cuda(ordinal).map_err(|err| {
                CoreError::DeviceError(format!(
                    "Failed to create CUDA device {}: {}",
                    ordinal, err
                ))
            }),
            #[cfg(feature = "metal")]
            (DeviceType::Metal, ordinal) => Device::new_metal(ordinal).map_err(|err| {
                CoreError::DeviceError(format!(
                    "Failed to create Metal device {}: {}",
                    ordinal, err
                ))
            }),
        }
    }
}
/// Check whether a CUDA device can be used.
///
/// Returns `true` only in builds with the `cuda` feature on Linux or Windows,
/// and only if CUDA device 0 can actually be opened. Every other build
/// (feature disabled, or an unsupported target OS such as macOS) returns
/// `false`.
pub fn is_cuda_available() -> bool {
    #[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
    {
        Device::new_cuda(0).is_ok()
    }
    // This gate must be the exact negation of the one above. Gating on just
    // `not(feature = "cuda")` left no branch compiled when `cuda` was enabled
    // on another OS, so the body evaluated to `()` instead of `bool`.
    #[cfg(not(all(feature = "cuda", any(target_os = "linux", target_os = "windows"))))]
    {
        false
    }
}
/// Reports whether Metal can be used.
///
/// Always `false` when the crate is built without the `metal` feature.
/// Otherwise, reports whether Metal device 0 opens successfully.
pub fn is_metal_available() -> bool {
    #[cfg(not(feature = "metal"))]
    {
        false
    }
    #[cfg(feature = "metal")]
    {
        Device::new_metal(0).is_ok()
    }
}
/// Picks the most capable device that opens, in the order CUDA > Metal > CPU.
///
/// Only device 0 of each accelerator backend is tried, and only if that
/// backend is compiled in. This never fails: if no accelerator opens, it
/// returns `Device::Cpu`. The chosen device is logged at `info` level.
pub fn get_best_device() -> Device {
    // CUDA wins whenever it is compiled in and device 0 opens.
    #[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
    {
        let probe = Device::new_cuda(0);
        if let Ok(gpu) = probe {
            tracing::info!("Using CUDA device 0");
            return gpu;
        }
    }

    // Next preference: Metal device 0.
    #[cfg(feature = "metal")]
    {
        let probe = Device::new_metal(0);
        if let Ok(gpu) = probe {
            tracing::info!("Using Metal device 0");
            return gpu;
        }
    }

    // No accelerator was usable.
    tracing::info!("Using CPU device");
    Device::Cpu
}
/// Lists the CUDA ordinals that can currently be opened.
///
/// Ordinals are probed upward from 0, and probing stops at the first one that
/// fails to open. The result is therefore always a contiguous `0..k` range.
/// At most 16 ordinals are tried.
#[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
pub fn get_cuda_devices() -> Vec<usize> {
    const MAX_PROBED_ORDINALS: usize = 16;

    (0..MAX_PROBED_ORDINALS)
        .take_while(|&ordinal| Device::new_cuda(ordinal).is_ok())
        .collect()
}
/// Lists the Metal devices that can currently be opened.
///
/// Returns `[0]` if Metal device 0 opens, otherwise an empty list.
#[cfg(feature = "metal")]
pub fn get_metal_devices() -> Vec<usize> {
    // Deliberately probes only device 0. Probing further ordinals triggered
    // panics (Vec index issues) inside candle-core's Metal backend; see the
    // candle issue tracker at https://github.com/huggingface/candle/issues.
    if Device::new_metal(0).is_ok() {
        vec![0]
    } else {
        Vec::new()
    }
}
/// Human-readable description of a single device, as produced by
/// [`get_device_info`] and [`list_devices`].
#[derive(Clone, Debug)]
pub struct DeviceInfo {
    /// Backend kind of the device.
    pub device_type: DeviceType,
    /// Backend ordinal of the device (0 for the CPU).
    pub device_id: usize,
    /// Display name, when one is known.
    pub name: Option<String>,
    /// Total device memory in bytes, when known.
    pub total_memory: Option<u64>,
    /// Currently free device memory in bytes, when known.
    pub available_memory: Option<u64>,
}
impl fmt::Display for DeviceInfo {
    /// Renders `"<type> Device <id>"`, followed by `" (<name>)"`,
    /// `" - Total Memory: <n> GB"` and `" - Available: <n> GB"` for whichever
    /// optional fields are present.
    ///
    /// Memory is shown as bytes divided by 2^30 with integer (truncating)
    /// division, so values below 1 GiB print as `0 GB`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const BYTES_PER_GIB: u64 = 1024 * 1024 * 1024;

        write!(f, "{} Device {}", self.device_type, self.device_id)?;

        if let Some(name) = self.name.as_deref() {
            write!(f, " ({})", name)?;
        }
        if let Some(gib) = self.total_memory.map(|bytes| bytes / BYTES_PER_GIB) {
            write!(f, " - Total Memory: {} GB", gib)?;
        }
        if let Some(gib) = self.available_memory.map(|bytes| bytes / BYTES_PER_GIB) {
            write!(f, " - Available: {} GB", gib)?;
        }

        Ok(())
    }
}
/// Describes a candle [`Device`] as a [`DeviceInfo`].
///
/// - CPU: id 0, named `"CPU"`.
/// - CUDA (when compiled in): the device's own ordinal.
/// - Metal (when compiled in): always id 0.
///
/// GPU names and memory figures are not queried and stay `None`. Any device
/// kind whose backend is not compiled in is reported as `DeviceType::Cpu`,
/// id 0, named `"Unknown"`.
pub fn get_device_info(device: &Device) -> DeviceInfo {
    let (device_type, device_id, name) = match device {
        Device::Cpu => (DeviceType::Cpu, 0, Some("CPU".to_string())),
        // Name and memory could be queried through the CUDA API; not done yet.
        #[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
        Device::Cuda(cuda_device) => (DeviceType::Cuda, cuda_device.ordinal(), None),
        // The ordinal is not read from the Metal handle; 0 is hard-coded,
        // matching `get_metal_devices`, which only ever probes device 0.
        #[cfg(feature = "metal")]
        Device::Metal(_) => (DeviceType::Metal, 0, None),
        _ => (DeviceType::Cpu, 0, Some("Unknown".to_string())),
    };

    DeviceInfo {
        device_type,
        device_id,
        name,
        total_memory: None,
        available_memory: None,
    }
}
/// Enumerates every device this build can use, with the CPU first.
///
/// The CPU entry is always present. If CUDA is compiled in, one entry follows
/// for each ordinal from `get_cuda_devices`. If Metal is compiled in, one
/// entry follows for each device from `get_metal_devices`. Devices that fail
/// to re-open while being described are skipped.
pub fn list_devices() -> Vec<DeviceInfo> {
    let devices = vec![DeviceInfo {
        device_type: DeviceType::Cpu,
        device_id: 0,
        name: Some("CPU".to_string()),
        total_memory: None,
        available_memory: None,
    }];

    // Each accelerator section re-binds `devices` instead of mutating a `mut`
    // binding. That way builds with no GPU feature need no `unused_mut` allow.
    #[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
    let devices = {
        let mut devices = devices;
        devices.extend(
            get_cuda_devices()
                .into_iter()
                .filter_map(|id| Device::new_cuda(id).ok())
                .map(|device| get_device_info(&device)),
        );
        devices
    };

    #[cfg(feature = "metal")]
    let devices = {
        let mut devices = devices;
        devices.extend(
            get_metal_devices()
                .into_iter()
                .filter_map(|id| Device::new_metal(id).ok())
                .map(|device| get_device_info(&device)),
        );
        devices
    };

    devices
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_device_config_default() {
        let config = DeviceConfig::default();
        assert_eq!(config.device_type, DeviceType::Cpu);
        assert_eq!(config.device_id, 0);
        assert!(!config.use_fp16);
        assert!(!config.use_tf32);
    }

    #[test]
    fn test_device_config_new_matches_default() {
        let from_new = DeviceConfig::new();
        let from_default = DeviceConfig::default();
        assert_eq!(from_new.device_type, from_default.device_type);
        assert_eq!(from_new.device_id, from_default.device_id);
        assert_eq!(from_new.use_fp16, from_default.use_fp16);
        assert_eq!(from_new.use_tf32, from_default.use_tf32);
    }

    #[test]
    fn test_device_config_builder() {
        let config = DeviceConfig::new()
            .with_device_type(DeviceType::Cpu)
            .with_device_id(1)
            .with_fp16(true)
            .with_tf32(true);
        assert_eq!(config.device_type, DeviceType::Cpu);
        assert_eq!(config.device_id, 1);
        assert!(config.use_fp16);
        assert!(config.use_tf32);
    }

    #[test]
    fn test_cpu_device_creation() {
        // CPU creation ignores `device_id` and cannot fail.
        let device = DeviceConfig::new()
            .with_device_id(7)
            .create_device()
            .expect("creating a CPU device must always succeed");
        assert!(matches!(device, Device::Cpu));
    }

    #[test]
    fn test_get_best_device() {
        let device = get_best_device();
        // Whatever was selected must be describable without panicking.
        let _ = get_device_info(&device);
        // With no accelerator backend compiled in, the CPU is the only choice.
        #[cfg(not(any(feature = "cuda", feature = "metal")))]
        assert!(matches!(device, Device::Cpu));
    }

    #[test]
    fn test_device_type_display() {
        assert_eq!(DeviceType::Cpu.to_string(), "CPU");
    }

    #[test]
    fn test_get_device_info_cpu() {
        let info = get_device_info(&Device::Cpu);
        assert_eq!(info.device_type, DeviceType::Cpu);
        assert_eq!(info.device_id, 0);
        assert_eq!(info.name.as_deref(), Some("CPU"));
        assert_eq!(info.total_memory, None);
        assert_eq!(info.available_memory, None);
    }

    #[test]
    fn test_list_devices() {
        let devices = list_devices();
        // The CPU entry is always present and always first.
        let cpu = devices
            .first()
            .expect("list_devices must always include the CPU");
        assert_eq!(cpu.device_type, DeviceType::Cpu);
        assert_eq!(cpu.device_id, 0);
        assert_eq!(cpu.name.as_deref(), Some("CPU"));
    }

    #[test]
    fn test_device_info_display() {
        let info = DeviceInfo {
            device_type: DeviceType::Cpu,
            device_id: 0,
            name: Some("Test CPU".to_string()),
            total_memory: Some(16 * 1024 * 1024 * 1024),    // 16 GiB
            available_memory: Some(8 * 1024 * 1024 * 1024), // 8 GiB
        };
        assert_eq!(
            info.to_string(),
            "CPU Device 0 (Test CPU) - Total Memory: 16 GB - Available: 8 GB"
        );
    }

    #[test]
    fn test_device_info_display_without_optional_fields() {
        let info = DeviceInfo {
            device_type: DeviceType::Cpu,
            device_id: 3,
            name: None,
            total_memory: None,
            available_memory: None,
        };
        assert_eq!(info.to_string(), "CPU Device 3");
    }

    #[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
    #[test]
    fn test_cuda_available() {
        // Just test that the function doesn't panic.
        let _ = is_cuda_available();
    }

    #[cfg(not(all(feature = "cuda", any(target_os = "linux", target_os = "windows"))))]
    #[test]
    fn test_cuda_unavailable_without_backend() {
        assert!(!is_cuda_available());
    }

    #[cfg(feature = "metal")]
    #[test]
    fn test_metal_available() {
        // Just test that the function doesn't panic.
        let _ = is_metal_available();
    }

    #[cfg(not(feature = "metal"))]
    #[test]
    fn test_metal_unavailable_without_feature() {
        assert!(!is_metal_available());
    }
}