kizzasi-core 0.2.1

Core SSM (State Space Model) engine for Kizzasi AGSP
Documentation
//! Device selection and GPU acceleration utilities
//!
//! Provides automatic device detection (CUDA/Metal/CPU) and device management
//! for efficient model training and inference on GPUs.
//!
//! # Features
//!
//! - **Auto-detection**: Automatically detects available CUDA/Metal devices
//! - **Fallback**: Gracefully falls back to CPU if GPU is unavailable
//! - **Memory Management**: Utilities for efficient GPU memory usage
//! - **Multi-GPU**: Support for selecting specific GPU devices
//!
//! # Examples
//!
//! ```rust
//! use kizzasi_core::device::{DeviceConfig, DeviceType, get_best_device};
//!
//! // Auto-select best available device
//! let device = get_best_device();
//!
//! // Or configure manually
//! let config = DeviceConfig::default()
//!     .with_device_type(DeviceType::Cpu)
//!     .with_device_id(0);
//! let device = config.create_device()?;
//! # Ok::<(), Box<dyn std::error::Error>>(())
//! ```

#[cfg(any(feature = "cuda", feature = "metal"))]
use crate::error::CoreError;
use crate::error::CoreResult;
use candle_core::Device;
use serde::{Deserialize, Serialize};
use std::fmt;

/// Compute backend on which a model is executed.
///
/// The set of variants depends on the enabled cargo features (and, for CUDA,
/// the target OS); [`DeviceType::Cpu`] exists in every build.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Deserialize, Serialize)]
pub enum DeviceType {
    /// Host CPU; needs no optional feature and is always available.
    Cpu,
    /// NVIDIA GPU via CUDA; compiled only with the `cuda` feature on Linux or Windows.
    #[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
    Cuda,
    /// Apple GPU via Metal; compiled only with the `metal` feature.
    #[cfg(feature = "metal")]
    Metal,
}

impl fmt::Display for DeviceType {
    /// Writes the short human-readable backend label (`CPU`, `CUDA`, `Metal`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match *self {
            DeviceType::Cpu => "CPU",
            #[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
            DeviceType::Cuda => "CUDA",
            #[cfg(feature = "metal")]
            DeviceType::Metal => "Metal",
        };
        f.write_str(label)
    }
}

/// Settings describing which compute device a model should run on.
///
/// Start from [`DeviceConfig::new`] (or `Default`), adjust it with the
/// `with_*` builder methods, then call [`DeviceConfig::create_device`].
///
/// Note: `create_device` only consults `device_type` and `device_id`; the
/// precision flags are carried in the configuration but are not applied there.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct DeviceConfig {
    /// Backend to run on.
    pub device_type: DeviceType,
    /// Device ordinal on multi-GPU machines (ignored for the CPU backend).
    pub device_id: usize,
    /// Request mixed precision (FP16).
    pub use_fp16: bool,
    /// Request TF32 for matmul; only meaningful for CUDA.
    pub use_tf32: bool,
}

impl Default for DeviceConfig {
    /// CPU device 0 with both reduced-precision modes switched off.
    fn default() -> Self {
        let device_type = DeviceType::Cpu;
        let device_id = 0;
        Self {
            device_type,
            device_id,
            use_fp16: bool::default(),
            use_tf32: bool::default(),
        }
    }
}

impl DeviceConfig {
    /// Create a new device configuration.
    ///
    /// Equivalent to `DeviceConfig::default()`: CPU device 0 with FP16 and
    /// TF32 disabled.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the device type (backend) to run on.
    #[must_use]
    pub fn with_device_type(mut self, device_type: DeviceType) -> Self {
        self.device_type = device_type;
        self
    }

    /// Set the device ordinal (for multi-GPU systems; ignored for the CPU).
    #[must_use]
    pub fn with_device_id(mut self, device_id: usize) -> Self {
        self.device_id = device_id;
        self
    }

    /// Set whether FP16 (half precision) is requested.
    ///
    /// The flag is stored in the configuration only;
    /// [`DeviceConfig::create_device`] does not consult it.
    #[must_use]
    pub fn with_fp16(mut self, enabled: bool) -> Self {
        self.use_fp16 = enabled;
        self
    }

    /// Set whether TF32 matmuls are requested (CUDA only).
    ///
    /// The flag is stored in the configuration only;
    /// [`DeviceConfig::create_device`] does not consult it.
    #[must_use]
    pub fn with_tf32(mut self, enabled: bool) -> Self {
        self.use_tf32 = enabled;
        self
    }

    /// Create a candle [`Device`] from this configuration.
    ///
    /// Only `device_type` and `device_id` are consulted. `device_id` is
    /// ignored for [`DeviceType::Cpu`], and the precision flags are not applied.
    ///
    /// # Errors
    ///
    /// Returns `CoreError::DeviceError` when candle fails to create the
    /// requested CUDA or Metal device. Creating the CPU device never fails.
    pub fn create_device(&self) -> CoreResult<Device> {
        match self.device_type {
            DeviceType::Cpu => Ok(Device::Cpu),

            #[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
            DeviceType::Cuda => Device::new_cuda(self.device_id).map_err(|e| {
                CoreError::DeviceError(format!(
                    "Failed to create CUDA device {}: {}",
                    self.device_id, e
                ))
            }),

            #[cfg(feature = "metal")]
            DeviceType::Metal => Device::new_metal(self.device_id).map_err(|e| {
                CoreError::DeviceError(format!(
                    "Failed to create Metal device {}: {}",
                    self.device_id, e
                ))
            }),
        }
    }
}

/// Check whether a CUDA device can be used.
///
/// Returns `true` only when the crate was built with the `cuda` feature on
/// Linux or Windows *and* CUDA device 0 can be created. Every other build
/// configuration, including the `cuda` feature on an unsupported OS,
/// returns `false`.
pub fn is_cuda_available() -> bool {
    #[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
    {
        Device::new_cuda(0).is_ok()
    }
    // Must be the exact complement of the condition above. Otherwise, with
    // `cuda` enabled on e.g. macOS, neither branch is compiled and the function
    // body becomes `()`, which does not compile.
    #[cfg(not(all(feature = "cuda", any(target_os = "linux", target_os = "windows"))))]
    {
        false
    }
}

/// Report whether a Metal device can be used.
///
/// With the `metal` feature this tries to open Metal device 0 and reports
/// whether that succeeded; without the feature it is always `false`.
pub fn is_metal_available() -> bool {
    #[cfg(feature = "metal")]
    fn probe() -> bool {
        Device::new_metal(0).is_ok()
    }

    #[cfg(not(feature = "metal"))]
    fn probe() -> bool {
        // Built without Metal support: nothing to probe.
        false
    }

    probe()
}

/// Pick the most capable device that can actually be opened.
///
/// Preference order is CUDA device 0, then Metal device 0, then the CPU. A GPU
/// backend is only tried when it is compiled in. The choice is logged at
/// `info` level. This never fails because the CPU is the final fallback.
pub fn get_best_device() -> Device {
    #[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
    if let Ok(cuda) = Device::new_cuda(0) {
        tracing::info!("Using CUDA device 0");
        return cuda;
    }

    #[cfg(feature = "metal")]
    if let Ok(metal) = Device::new_metal(0) {
        tracing::info!("Using Metal device 0");
        return metal;
    }

    // No accelerator was compiled in or none could be opened.
    tracing::info!("Using CPU device");
    Device::Cpu
}

/// List the CUDA device ordinals that can be opened.
///
/// Ordinals `0, 1, 2, ...` are probed in order, at most 16 of them. Probing
/// stops at the first ordinal that fails to open, so the result is always a
/// contiguous prefix `0..k`.
#[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
pub fn get_cuda_devices() -> Vec<usize> {
    // Upper bound on the number of ordinals probed.
    const MAX_PROBED: usize = 16;

    (0..MAX_PROBED)
        .take_while(|&ordinal| Device::new_cuda(ordinal).is_ok())
        .collect()
}

/// List the Metal device ordinals that can be opened.
///
/// The result is either `[0]` or empty.
#[cfg(feature = "metal")]
pub fn get_metal_devices() -> Vec<usize> {
    // Deliberately probe only ordinal 0: probing further Metal devices has
    // caused panics (Vec index issues) in candle-core's Metal backend.
    // See: https://github.com/huggingface/candle/issues
    const ONLY_PROBED_ID: usize = 0;

    if Device::new_metal(ONLY_PROBED_ID).is_err() {
        return Vec::new();
    }
    vec![ONLY_PROBED_ID]
}

/// Descriptive summary of a compute device.
///
/// Optional fields are `None` whenever the value is not known.
#[derive(Clone, Debug)]
pub struct DeviceInfo {
    /// Backend the device belongs to.
    pub device_type: DeviceType,
    /// Ordinal of the device within its backend.
    pub device_id: usize,
    /// Human-readable device name, when known.
    pub name: Option<String>,
    /// Total device memory in bytes, when known.
    pub total_memory: Option<u64>,
    /// Available device memory in bytes, when known.
    pub available_memory: Option<u64>,
}

impl fmt::Display for DeviceInfo {
    /// Renders `"<type> Device <id>"`. If present, it then appends
    /// `" (<name>)"`, `" - Total Memory: <n> GB"` and `" - Available: <n> GB"`.
    ///
    /// Memory is shown in whole units of 1024^3 bytes, truncated toward zero.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const BYTES_PER_GB: u64 = 1024 * 1024 * 1024;

        write!(f, "{} Device {}", self.device_type, self.device_id)?;

        if let Some(name) = self.name.as_deref() {
            write!(f, " ({})", name)?;
        }
        if let Some(total_gb) = self.total_memory.map(|bytes| bytes / BYTES_PER_GB) {
            write!(f, " - Total Memory: {} GB", total_gb)?;
        }
        if let Some(free_gb) = self.available_memory.map(|bytes| bytes / BYTES_PER_GB) {
            write!(f, " - Available: {} GB", free_gb)?;
        }

        Ok(())
    }
}

/// Describe a candle [`Device`] as a [`DeviceInfo`].
///
/// - **CPU**: named `"CPU"` with ordinal 0.
/// - **GPU**: only the backend and ordinal are filled in; name and memory are
///   not queried yet and are reported as `None`. Metal devices are always
///   reported with ordinal 0.
/// - **Backend not compiled into this crate**: falls back to a CPU-typed entry
///   named `"Unknown"`.
pub fn get_device_info(device: &Device) -> DeviceInfo {
    let (device_type, device_id, name) = match device {
        Device::Cpu => (DeviceType::Cpu, 0, Some("CPU".to_string())),

        // Name/memory could be queried via the CUDA API; not done yet.
        #[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
        Device::Cuda(cuda_device) => (DeviceType::Cuda, cuda_device.ordinal(), None),

        // The ordinal is hard-coded to 0; name/memory could be queried via the
        // Metal API, which is not done yet.
        #[cfg(feature = "metal")]
        Device::Metal(_) => (DeviceType::Metal, 0, None),

        _ => (DeviceType::Cpu, 0, Some("Unknown".to_string())),
    };

    DeviceInfo {
        device_type,
        device_id,
        name,
        total_memory: None,
        available_memory: None,
    }
}

/// List all available devices
pub fn list_devices() -> Vec<DeviceInfo> {
    #[allow(unused_mut)]
    let mut result = vec![DeviceInfo {
        device_type: DeviceType::Cpu,
        device_id: 0,
        name: Some("CPU".to_string()),
        total_memory: None,
        available_memory: None,
    }];

    #[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
    {
        for id in get_cuda_devices() {
            if let Ok(device) = Device::new_cuda(id) {
                result.push(get_device_info(&device));
            }
        }
    }

    #[cfg(feature = "metal")]
    {
        for id in get_metal_devices() {
            if let Ok(device) = Device::new_metal(id) {
                result.push(get_device_info(&device));
            }
        }
    }

    result
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_device_config_default() {
        let config = DeviceConfig::default();
        assert_eq!(config.device_type, DeviceType::Cpu);
        assert_eq!(config.device_id, 0);
        assert!(!config.use_fp16);
        assert!(!config.use_tf32);
    }

    #[test]
    fn test_device_config_builder() {
        let config = DeviceConfig::new()
            .with_device_type(DeviceType::Cpu)
            .with_device_id(1)
            .with_fp16(true)
            .with_tf32(true);

        assert_eq!(config.device_type, DeviceType::Cpu);
        assert_eq!(config.device_id, 1);
        assert!(config.use_fp16);
        assert!(config.use_tf32);
    }

    #[test]
    fn test_cpu_device_creation() {
        let config = DeviceConfig::new();
        let device = config.create_device().unwrap();
        assert!(matches!(device, Device::Cpu));
    }

    #[test]
    fn test_cpu_device_creation_ignores_device_id() {
        // The CPU backend has no ordinals; a non-zero id must not cause an error.
        let device = DeviceConfig::new()
            .with_device_id(3)
            .create_device()
            .unwrap();
        assert!(matches!(device, Device::Cpu));
    }

    #[test]
    fn test_get_best_device() {
        // Any backend may be chosen when GPU features are enabled, but the
        // result must always be describable.
        let device = get_best_device();
        let _ = get_device_info(&device);
    }

    #[cfg(not(any(feature = "cuda", feature = "metal")))]
    #[test]
    fn test_get_best_device_falls_back_to_cpu() {
        assert!(matches!(get_best_device(), Device::Cpu));
    }

    #[test]
    fn test_cpu_device_info() {
        let info = get_device_info(&Device::Cpu);
        assert_eq!(info.device_type, DeviceType::Cpu);
        assert_eq!(info.device_id, 0);
        assert_eq!(info.name.as_deref(), Some("CPU"));
        assert!(info.total_memory.is_none());
        assert!(info.available_memory.is_none());
    }

    #[test]
    fn test_list_devices() {
        let devices = list_devices();
        // Should always have at least the CPU, listed first.
        assert!(!devices.is_empty());
        assert_eq!(devices[0].device_type, DeviceType::Cpu);
        assert_eq!(devices[0].name.as_deref(), Some("CPU"));
    }

    #[test]
    fn test_device_type_display() {
        assert_eq!(DeviceType::Cpu.to_string(), "CPU");
    }

    #[test]
    fn test_device_info_display() {
        let info = DeviceInfo {
            device_type: DeviceType::Cpu,
            device_id: 0,
            name: Some("Test CPU".to_string()),
            total_memory: Some(16 * 1024 * 1024 * 1024), // 16 GB
            available_memory: Some(8 * 1024 * 1024 * 1024), // 8 GB
        };
        assert_eq!(
            info.to_string(),
            "CPU Device 0 (Test CPU) - Total Memory: 16 GB - Available: 8 GB"
        );
    }

    #[test]
    fn test_device_info_display_minimal() {
        let info = DeviceInfo {
            device_type: DeviceType::Cpu,
            device_id: 2,
            name: None,
            total_memory: None,
            available_memory: None,
        };
        assert_eq!(info.to_string(), "CPU Device 2");
    }

    #[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
    #[test]
    fn test_cuda_available() {
        // Just test that the function doesn't panic
        let _ = is_cuda_available();
    }

    #[cfg(not(all(feature = "cuda", any(target_os = "linux", target_os = "windows"))))]
    #[test]
    fn test_cuda_unavailable_without_feature() {
        assert!(!is_cuda_available());
    }

    #[cfg(feature = "metal")]
    #[test]
    fn test_metal_available() {
        // Just test that the function doesn't panic
        let _ = is_metal_available();
    }

    #[cfg(not(feature = "metal"))]
    #[test]
    fn test_metal_unavailable_without_feature() {
        assert!(!is_metal_available());
    }
}