kizzasi_core/
device.rs

1//! Device selection and GPU acceleration utilities
2//!
3//! Provides automatic device detection (CUDA/Metal/CPU) and device management
4//! for efficient model training and inference on GPUs.
5//!
6//! # Features
7//!
8//! - **Auto-detection**: Automatically detects available CUDA/Metal devices
9//! - **Fallback**: Gracefully falls back to CPU if GPU is unavailable
10//! - **Memory Management**: Utilities for efficient GPU memory usage
11//! - **Multi-GPU**: Support for selecting specific GPU devices
12//!
13//! # Examples
14//!
15//! ```rust
16//! use kizzasi_core::device::{DeviceConfig, DeviceType, get_best_device};
17//!
18//! // Auto-select best available device
19//! let device = get_best_device();
20//!
21//! // Or configure manually
22//! let config = DeviceConfig::default()
23//!     .with_device_type(DeviceType::Cpu)
24//!     .with_device_id(0);
25//! let device = config.create_device()?;
26//! # Ok::<(), Box<dyn std::error::Error>>(())
27//! ```
28
29#[cfg(any(feature = "cuda", feature = "metal"))]
30use crate::error::CoreError;
31use crate::error::CoreResult;
32use candle_core::Device;
33use serde::{Deserialize, Serialize};
34use std::fmt;
35
36/// Device type for model execution
37#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
38pub enum DeviceType {
39    /// CPU execution (always available)
40    Cpu,
41    /// NVIDIA CUDA GPU (requires cuda feature)
42    #[cfg(feature = "cuda")]
43    Cuda,
44    /// Apple Metal GPU
45    #[cfg(feature = "metal")]
46    Metal,
47}
48
49impl fmt::Display for DeviceType {
50    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
51        match self {
52            DeviceType::Cpu => write!(f, "CPU"),
53            #[cfg(feature = "cuda")]
54            DeviceType::Cuda => write!(f, "CUDA"),
55            #[cfg(feature = "metal")]
56            DeviceType::Metal => write!(f, "Metal"),
57        }
58    }
59}
60
61/// Device configuration for GPU acceleration
62#[derive(Debug, Clone, Serialize, Deserialize)]
63pub struct DeviceConfig {
64    /// Device type to use
65    pub device_type: DeviceType,
66    /// Device ID (for multi-GPU systems)
67    pub device_id: usize,
68    /// Enable mixed precision (FP16)
69    pub use_fp16: bool,
70    /// Enable TF32 for matmul (CUDA only)
71    pub use_tf32: bool,
72}
73
74impl Default for DeviceConfig {
75    fn default() -> Self {
76        Self {
77            device_type: DeviceType::Cpu,
78            device_id: 0,
79            use_fp16: false,
80            use_tf32: false,
81        }
82    }
83}
84
85impl DeviceConfig {
86    /// Create a new device configuration
87    pub fn new() -> Self {
88        Self::default()
89    }
90
91    /// Set device type
92    pub fn with_device_type(mut self, device_type: DeviceType) -> Self {
93        self.device_type = device_type;
94        self
95    }
96
97    /// Set device ID
98    pub fn with_device_id(mut self, device_id: usize) -> Self {
99        self.device_id = device_id;
100        self
101    }
102
103    /// Enable FP16 precision
104    pub fn with_fp16(mut self, enabled: bool) -> Self {
105        self.use_fp16 = enabled;
106        self
107    }
108
109    /// Enable TF32 precision (CUDA only)
110    pub fn with_tf32(mut self, enabled: bool) -> Self {
111        self.use_tf32 = enabled;
112        self
113    }
114
115    /// Create a candle Device from this configuration
116    pub fn create_device(&self) -> CoreResult<Device> {
117        match self.device_type {
118            DeviceType::Cpu => Ok(Device::Cpu),
119
120            #[cfg(feature = "cuda")]
121            DeviceType::Cuda => {
122                #[cfg(any(target_os = "linux", target_os = "windows"))]
123                {
124                    Device::new_cuda(self.device_id).map_err(|e| {
125                        CoreError::DeviceError(format!(
126                            "Failed to create CUDA device {}: {}",
127                            self.device_id, e
128                        ))
129                    })
130                }
131                #[cfg(not(any(target_os = "linux", target_os = "windows")))]
132                {
133                    Err(CoreError::DeviceError(
134                        "CUDA is not supported on this platform (requires Linux or Windows)"
135                            .to_string(),
136                    ))
137                }
138            }
139
140            #[cfg(feature = "metal")]
141            DeviceType::Metal => Device::new_metal(self.device_id).map_err(|e| {
142                CoreError::DeviceError(format!(
143                    "Failed to create Metal device {}: {}",
144                    self.device_id, e
145                ))
146            }),
147        }
148    }
149}
150
/// Reports whether CUDA device 0 can be created.
///
/// Always `false` unless the `cuda` feature is enabled and the target OS is
/// Linux or Windows; otherwise the answer comes from actually trying to create
/// device 0.
pub fn is_cuda_available() -> bool {
    // Exactly one of the two bindings below survives conditional compilation.
    #[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
    let available: bool = Device::new_cuda(0).is_ok();

    #[cfg(not(all(feature = "cuda", any(target_os = "linux", target_os = "windows"))))]
    let available: bool = false;

    available
}
162
/// Reports whether Metal device 0 can be created.
///
/// Always `false` unless the `metal` feature is enabled; otherwise the answer
/// comes from actually trying to create device 0.
pub fn is_metal_available() -> bool {
    // Exactly one of the two bindings below survives conditional compilation.
    #[cfg(feature = "metal")]
    let available: bool = Device::new_metal(0).is_ok();

    #[cfg(not(feature = "metal"))]
    let available: bool = false;

    available
}
174
175/// Get the best available device (CUDA > Metal > CPU)
176pub fn get_best_device() -> Device {
177    #[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
178    {
179        if let Ok(device) = Device::new_cuda(0) {
180            tracing::info!("Using CUDA device 0");
181            return device;
182        }
183    }
184
185    #[cfg(feature = "metal")]
186    {
187        if let Ok(device) = Device::new_metal(0) {
188            tracing::info!("Using Metal device 0");
189            return device;
190        }
191    }
192
193    tracing::info!("Using CPU device");
194    Device::Cpu
195}
196
/// Returns the ordinals of the CUDA devices that can be created.
///
/// Ordinals are probed in ascending order starting at 0 and probing stops at
/// the first ordinal that fails, so the result is always a contiguous range
/// `0..k`. At most 16 devices are checked.
#[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
pub fn get_cuda_devices() -> Vec<usize> {
    // Upper bound on how many ordinals are probed.
    const MAX_PROBED_DEVICES: usize = 16;

    (0..MAX_PROBED_DEVICES)
        .take_while(|&ordinal| Device::new_cuda(ordinal).is_ok())
        .collect()
}
211
/// Returns the ordinals of the Metal devices that can be created.
///
/// The result is either `[0]` or empty: only device 0 is ever probed.
#[cfg(feature = "metal")]
pub fn get_metal_devices() -> Vec<usize> {
    // Only check device 0 to avoid candle-core Metal backend panics with multiple devices
    // See: https://github.com/huggingface/candle/issues (Metal backend has Vec index issues)
    match Device::new_metal(0) {
        Ok(_) => vec![0],
        Err(_) => Vec::new(),
    }
}
223
224/// Device information
225#[derive(Debug, Clone)]
226pub struct DeviceInfo {
227    /// Device type
228    pub device_type: DeviceType,
229    /// Device ID
230    pub device_id: usize,
231    /// Device name (if available)
232    pub name: Option<String>,
233    /// Total memory (bytes, if available)
234    pub total_memory: Option<u64>,
235    /// Available memory (bytes, if available)
236    pub available_memory: Option<u64>,
237}
238
239impl fmt::Display for DeviceInfo {
240    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
241        write!(f, "{} Device {}", self.device_type, self.device_id)?;
242        if let Some(name) = &self.name {
243            write!(f, " ({})", name)?;
244        }
245        if let Some(total) = self.total_memory {
246            write!(f, " - Total Memory: {} GB", total / (1024 * 1024 * 1024))?;
247        }
248        if let Some(available) = self.available_memory {
249            write!(f, " - Available: {} GB", available / (1024 * 1024 * 1024))?;
250        }
251        Ok(())
252    }
253}
254
255/// Get information about a device
256pub fn get_device_info(device: &Device) -> DeviceInfo {
257    match device {
258        Device::Cpu => DeviceInfo {
259            device_type: DeviceType::Cpu,
260            device_id: 0,
261            name: Some("CPU".to_string()),
262            total_memory: None,
263            available_memory: None,
264        },
265
266        #[cfg(feature = "cuda")]
267        Device::Cuda(_cuda_device) => {
268            // Note: CudaDevice no longer has ordinal() method in candle-core 0.9.1
269            // Using 0 as default device ID. For actual device ID, would need to track
270            // it separately or use CUDA runtime API directly.
271            DeviceInfo {
272                device_type: DeviceType::Cuda,
273                device_id: 0,
274                name: None,             // Could query via CUDA API
275                total_memory: None,     // Could query via CUDA API
276                available_memory: None, // Could query via CUDA API
277            }
278        }
279
280        #[cfg(feature = "metal")]
281        Device::Metal(_metal_device) => {
282            DeviceInfo {
283                device_type: DeviceType::Metal,
284                device_id: 0,           // Metal devices are numbered sequentially
285                name: None,             // Could query via Metal API
286                total_memory: None,     // Could query via Metal API
287                available_memory: None, // Could query via Metal API
288            }
289        }
290
291        // Catch-all for unhandled device variants (e.g., Metal when only cuda feature is enabled)
292        // This is needed because candle_core::Device always has all variants regardless of features
293        #[allow(unreachable_patterns)]
294        _ => DeviceInfo {
295            device_type: DeviceType::Cpu,
296            device_id: 0,
297            name: Some("Unknown".to_string()),
298            total_memory: None,
299            available_memory: None,
300        },
301    }
302}
303
304/// List all available devices
305pub fn list_devices() -> Vec<DeviceInfo> {
306    #[allow(unused_mut)]
307    let mut result = vec![DeviceInfo {
308        device_type: DeviceType::Cpu,
309        device_id: 0,
310        name: Some("CPU".to_string()),
311        total_memory: None,
312        available_memory: None,
313    }];
314
315    #[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
316    {
317        for id in get_cuda_devices() {
318            if let Ok(device) = Device::new_cuda(id) {
319                result.push(get_device_info(&device));
320            }
321        }
322    }
323
324    #[cfg(feature = "metal")]
325    {
326        for id in get_metal_devices() {
327            if let Ok(device) = Device::new_metal(id) {
328                result.push(get_device_info(&device));
329            }
330        }
331    }
332
333    result
334}
335
#[cfg(test)]
mod tests {
    //! Unit tests for device configuration, selection and reporting.
    //!
    //! GPU-specific tests only check that probing does not panic, because the
    //! hardware may be absent on the machine running the tests.

    use super::*;

    #[test]
    fn test_device_config_default() {
        let config = DeviceConfig::default();
        assert_eq!(config.device_type, DeviceType::Cpu);
        assert_eq!(config.device_id, 0);
        assert!(!config.use_fp16);
        assert!(!config.use_tf32);
    }

    #[test]
    fn test_device_config_builder() {
        let config = DeviceConfig::new()
            .with_device_type(DeviceType::Cpu)
            .with_device_id(1)
            .with_fp16(true)
            .with_tf32(true);

        assert_eq!(config.device_type, DeviceType::Cpu);
        assert_eq!(config.device_id, 1);
        assert!(config.use_fp16);
        assert!(config.use_tf32);
    }

    #[test]
    fn test_cpu_device_creation() {
        let config = DeviceConfig::new();
        let device = config.create_device().unwrap();
        assert!(matches!(device, Device::Cpu));
    }

    #[test]
    fn test_get_best_device() {
        // Must always succeed; which backend is chosen depends on the enabled
        // features and the hardware, so only successful completion is checked.
        let _device = get_best_device();
    }

    #[test]
    fn test_list_devices() {
        let devices = list_devices();
        // The CPU entry is unconditional and always comes first.
        assert!(!devices.is_empty());
        assert_eq!(devices[0].device_type, DeviceType::Cpu);
        assert_eq!(devices[0].device_id, 0);
    }

    #[test]
    fn test_device_type_display() {
        assert_eq!(DeviceType::Cpu.to_string(), "CPU");
    }

    #[test]
    fn test_get_device_info_cpu() {
        let info = get_device_info(&Device::Cpu);
        assert_eq!(info.device_type, DeviceType::Cpu);
        assert_eq!(info.device_id, 0);
        assert_eq!(info.name.as_deref(), Some("CPU"));
        assert!(info.total_memory.is_none());
        assert!(info.available_memory.is_none());
    }

    #[test]
    fn test_device_info_display() {
        let info = DeviceInfo {
            device_type: DeviceType::Cpu,
            device_id: 0,
            name: Some("Test CPU".to_string()),
            total_memory: Some(16 * 1024 * 1024 * 1024), // 16 GB
            available_memory: Some(8 * 1024 * 1024 * 1024), // 8 GB
        };
        let display = format!("{}", info);
        assert!(display.contains("CPU"));
        assert!(display.contains("Test CPU"));
        assert!(display.contains("16 GB"));
        // Pin the whole rendering, including the available-memory segment.
        assert_eq!(
            display,
            "CPU Device 0 (Test CPU) - Total Memory: 16 GB - Available: 8 GB"
        );
    }

    #[test]
    fn test_device_info_display_without_optional_fields() {
        let info = DeviceInfo {
            device_type: DeviceType::Cpu,
            device_id: 0,
            name: None,
            total_memory: None,
            available_memory: None,
        };
        assert_eq!(info.to_string(), "CPU Device 0");
    }

    #[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
    #[test]
    fn test_cuda_available() {
        // Just test that the function doesn't panic
        let _ = is_cuda_available();
    }

    #[cfg(feature = "metal")]
    #[test]
    fn test_metal_available() {
        // Just test that the function doesn't panic
        let _ = is_metal_available();
    }
}