1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
//! Device selection and GPU acceleration utilities
//!
//! Provides automatic device detection (CUDA/Metal/CPU) and device management
//! for efficient model training and inference on GPUs.
//!
//! # Features
//!
//! - **Auto-detection**: Automatically detects available CUDA/Metal devices
//! - **Fallback**: Gracefully falls back to CPU if GPU is unavailable
//! - **Memory Management**: Utilities for efficient GPU memory usage
//! - **Multi-GPU**: Support for selecting specific GPU devices
//!
//! # Examples
//!
//! ```rust
//! use kizzasi_core::device::{DeviceConfig, DeviceType, get_best_device};
//!
//! // Auto-select best available device
//! let device = get_best_device();
//!
//! // Or configure manually
//! let config = DeviceConfig::default()
//!     .with_device_type(DeviceType::Cpu)
//!     .with_device_id(0);
//! let device = config.create_device()?;
//! # Ok::<(), Box<dyn std::error::Error>>(())
//! ```
#[cfg(any(feature = "cuda", feature = "metal"))]
use crate::error::CoreError;
use crate::error::CoreResult;
use candle_core::Device;
use serde::{Deserialize, Serialize};
use std::fmt;
/// Device type for model execution
///
/// `Cpu` is always present. The `Cuda` and `Metal` variants are compiled in
/// only when the crate's `cuda` / `metal` Cargo features are enabled, so any
/// code that names them has to be gated on the same feature.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DeviceType {
/// CPU execution (always available)
Cpu,
/// NVIDIA CUDA GPU (requires the `cuda` feature)
#[cfg(feature = "cuda")]
Cuda,
/// Apple Metal GPU (requires the `metal` feature)
#[cfg(feature = "metal")]
Metal,
}
impl fmt::Display for DeviceType {
    /// Writes the short backend name: `CPU`, `CUDA` or `Metal`.
    ///
    /// The name is written verbatim; width/alignment flags on the formatter
    /// are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Pick the label first, then emit it in one place.
        let label = match *self {
            DeviceType::Cpu => "CPU",
            #[cfg(feature = "cuda")]
            DeviceType::Cuda => "CUDA",
            #[cfg(feature = "metal")]
            DeviceType::Metal => "Metal",
        };
        f.write_str(label)
    }
}
/// Device configuration for GPU acceleration
///
/// Plain data describing which backend and which device index to open, plus
/// two precision flags. Build one with `DeviceConfig::new` / `Default` and the
/// `with_*` methods, then call `DeviceConfig::create_device`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeviceConfig {
/// Device type to use
pub device_type: DeviceType,
/// Device ID (for multi-GPU systems). `create_device` passes it to the
/// CUDA/Metal constructor; it is not used for the CPU backend.
pub device_id: usize,
/// Enable mixed precision (FP16).
/// NOTE(review): not read by `create_device` in this module; presumably
/// consumed by callers. Confirm where it takes effect.
pub use_fp16: bool,
/// Enable TF32 for matmul (CUDA only).
/// NOTE(review): likewise not read by `create_device` in this module.
/// Confirm where it takes effect.
pub use_tf32: bool,
}
impl Default for DeviceConfig {
fn default() -> Self {
Self {
device_type: DeviceType::Cpu,
device_id: 0,
use_fp16: false,
use_tf32: false,
}
}
}
impl DeviceConfig {
    /// Builds a configuration with the default settings (same as `Default`).
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns the configuration with its backend replaced by `device_type`.
    pub fn with_device_type(self, device_type: DeviceType) -> Self {
        Self {
            device_type,
            ..self
        }
    }

    /// Returns the configuration with its device index replaced by `device_id`.
    pub fn with_device_id(self, device_id: usize) -> Self {
        Self { device_id, ..self }
    }

    /// Returns the configuration with the FP16 flag set to `enabled`.
    pub fn with_fp16(self, enabled: bool) -> Self {
        Self {
            use_fp16: enabled,
            ..self
        }
    }

    /// Returns the configuration with the TF32 flag (CUDA only) set to `enabled`.
    pub fn with_tf32(self, enabled: bool) -> Self {
        Self {
            use_tf32: enabled,
            ..self
        }
    }

    /// Opens the candle `Device` described by this configuration.
    ///
    /// Only `device_type` and `device_id` are consulted.
    ///
    /// # Errors
    ///
    /// Returns `CoreError::DeviceError` when the CUDA or Metal backend fails
    /// to open device `device_id`, or when CUDA is requested on a target OS
    /// other than Linux or Windows. The CPU backend never fails.
    pub fn create_device(&self) -> CoreResult<Device> {
        match self.device_type {
            DeviceType::Cpu => Ok(Device::Cpu),
            #[cfg(feature = "cuda")]
            DeviceType::Cuda => {
                #[cfg(any(target_os = "linux", target_os = "windows"))]
                {
                    Device::new_cuda(self.device_id).map_err(|cause| {
                        let message = format!(
                            "Failed to create CUDA device {}: {}",
                            self.device_id, cause
                        );
                        CoreError::DeviceError(message)
                    })
                }
                // The `cuda` feature can be enabled on a platform that has no
                // CUDA runtime; report that instead of trying to open a device.
                #[cfg(not(any(target_os = "linux", target_os = "windows")))]
                {
                    let message = String::from(
                        "CUDA is not supported on this platform (requires Linux or Windows)",
                    );
                    Err(CoreError::DeviceError(message))
                }
            }
            #[cfg(feature = "metal")]
            DeviceType::Metal => Device::new_metal(self.device_id).map_err(|cause| {
                let message = format!(
                    "Failed to create Metal device {}: {}",
                    self.device_id, cause
                );
                CoreError::DeviceError(message)
            }),
        }
    }
}
/// Reports whether CUDA device 0 can be opened.
///
/// Always `false` when the `cuda` feature is disabled or the target OS is
/// neither Linux nor Windows. Otherwise the answer comes from actually
/// attempting `Device::new_cuda(0)`.
pub fn is_cuda_available() -> bool {
    #[cfg(not(all(feature = "cuda", any(target_os = "linux", target_os = "windows"))))]
    {
        false
    }
    #[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
    {
        let probe = Device::new_cuda(0);
        probe.is_ok()
    }
}
/// Reports whether Metal device 0 can be opened.
///
/// Always `false` when the `metal` feature is disabled. Otherwise the answer
/// comes from actually attempting `Device::new_metal(0)`.
pub fn is_metal_available() -> bool {
    #[cfg(not(feature = "metal"))]
    {
        false
    }
    #[cfg(feature = "metal")]
    {
        let probe = Device::new_metal(0);
        probe.is_ok()
    }
}
/// Picks the most capable device that can be opened, in the order
/// CUDA > Metal > CPU.
///
/// Each accelerator is probed at index 0 only, and only when its feature
/// (and, for CUDA, a Linux/Windows target) is compiled in. A failed probe is
/// ignored and the next candidate is tried. The function never fails because
/// the CPU device is the final fallback. The choice is logged at `info` level.
pub fn get_best_device() -> Device {
    #[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
    {
        let cuda_probe = Device::new_cuda(0);
        if let Ok(cuda) = cuda_probe {
            tracing::info!("Using CUDA device 0");
            return cuda;
        }
    }

    #[cfg(feature = "metal")]
    {
        let metal_probe = Device::new_metal(0);
        if let Ok(metal) = metal_probe {
            tracing::info!("Using Metal device 0");
            return metal;
        }
    }

    // No accelerator compiled in, or none could be opened.
    tracing::info!("Using CPU device");
    Device::Cpu
}
/// Lists the indices of the CUDA devices that can be opened.
///
/// Indices 0 through 15 are probed in ascending order with
/// `Device::new_cuda`, and probing stops at the first index that fails. The
/// result is therefore a contiguous run starting at 0, at most 16 long.
#[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
pub fn get_cuda_devices() -> Vec<usize> {
    // Check up to 16 devices
    (0..16)
        .take_while(|&ordinal| Device::new_cuda(ordinal).is_ok())
        .collect()
}
/// Lists the indices of the Metal devices that can be opened.
///
/// Returns `[0]` when `Device::new_metal(0)` succeeds and an empty vector
/// otherwise; higher indices are never probed.
#[cfg(feature = "metal")]
pub fn get_metal_devices() -> Vec<usize> {
    // Only check device 0 to avoid candle-core Metal backend panics with multiple devices
    // See: https://github.com/huggingface/candle/issues (Metal backend has Vec index issues)
    if Device::new_metal(0).is_ok() {
        vec![0]
    } else {
        Vec::new()
    }
}
/// Device information
///
/// Descriptive record for a single device, as produced by `get_device_info`
/// and `list_devices`. The optional fields are `None` when the value is not
/// known.
#[derive(Debug, Clone)]
pub struct DeviceInfo {
/// Device type
pub device_type: DeviceType,
/// Device ID
pub device_id: usize,
/// Device name (if available)
pub name: Option<String>,
/// Total memory (bytes, if available)
pub total_memory: Option<u64>,
/// Available memory (bytes, if available)
pub available_memory: Option<u64>,
}
impl fmt::Display for DeviceInfo {
    /// Renders `"<type> Device <id>"`, followed by whichever optional parts
    /// are present, in this order: ` (<name>)`, ` - Total Memory: <n> GB`,
    /// ` - Available: <n> GB`.
    ///
    /// Memory figures are byte counts divided by 1024^3 with integer
    /// division, so they are rounded down to a whole number.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const BYTES_PER_GB: u64 = 1024 * 1024 * 1024;

        write!(f, "{} Device {}", self.device_type, self.device_id)?;

        if let Some(label) = self.name.as_deref() {
            write!(f, " ({})", label)?;
        }
        if let Some(total_gb) = self.total_memory.map(|bytes| bytes / BYTES_PER_GB) {
            write!(f, " - Total Memory: {} GB", total_gb)?;
        }
        if let Some(free_gb) = self.available_memory.map(|bytes| bytes / BYTES_PER_GB) {
            write!(f, " - Available: {} GB", free_gb)?;
        }

        Ok(())
    }
}
/// Get information about a device
pub fn get_device_info(device: &Device) -> DeviceInfo {
match device {
Device::Cpu => DeviceInfo {
device_type: DeviceType::Cpu,
device_id: 0,
name: Some("CPU".to_string()),
total_memory: None,
available_memory: None,
},
#[cfg(feature = "cuda")]
Device::Cuda(_cuda_device) => {
// Note: CudaDevice no longer has ordinal() method in candle-core 0.9.1
// Using 0 as default device ID. For actual device ID, would need to track
// it separately or use CUDA runtime API directly.
DeviceInfo {
device_type: DeviceType::Cuda,
device_id: 0,
name: None, // Could query via CUDA API
total_memory: None, // Could query via CUDA API
available_memory: None, // Could query via CUDA API
}
}
#[cfg(feature = "metal")]
Device::Metal(_metal_device) => {
DeviceInfo {
device_type: DeviceType::Metal,
device_id: 0, // Metal devices are numbered sequentially
name: None, // Could query via Metal API
total_memory: None, // Could query via Metal API
available_memory: None, // Could query via Metal API
}
}
// Catch-all for unhandled device variants (e.g., Metal when only cuda feature is enabled)
// This is needed because candle_core::Device always has all variants regardless of features
#[allow(unreachable_patterns)]
_ => DeviceInfo {
device_type: DeviceType::Cpu,
device_id: 0,
name: Some("Unknown".to_string()),
total_memory: None,
available_memory: None,
},
}
}
/// Lists every device that can currently be opened.
///
/// The first entry is always the CPU. It is followed by one entry per CUDA
/// device reported by `get_cuda_devices` and one per Metal device reported by
/// `get_metal_devices`, each only when the corresponding feature is compiled
/// in and the device opens successfully.
///
/// Each GPU entry carries the index it was opened with as its `device_id`.
pub fn list_devices() -> Vec<DeviceInfo> {
    // Without a GPU feature nothing below pushes to `result`.
    #[allow(unused_mut)]
    let mut result = vec![DeviceInfo {
        device_type: DeviceType::Cpu,
        device_id: 0,
        name: Some("CPU".to_string()),
        total_memory: None,
        available_memory: None,
    }];

    #[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
    {
        for id in get_cuda_devices() {
            if let Ok(device) = Device::new_cuda(id) {
                // `get_device_info` cannot recover the ordinal from the device
                // handle and always reports 0, which would make every GPU in a
                // multi-GPU system look like device 0. Record the ordinal that
                // was actually opened instead.
                result.push(DeviceInfo {
                    device_id: id,
                    ..get_device_info(&device)
                });
            }
        }
    }

    #[cfg(feature = "metal")]
    {
        for id in get_metal_devices() {
            if let Ok(device) = Device::new_metal(id) {
                // Same reasoning as for CUDA: keep the index we opened.
                result.push(DeviceInfo {
                    device_id: id,
                    ..get_device_info(&device)
                });
            }
        }
    }

    result
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration targets CPU 0 with FP16 and TF32 disabled.
    #[test]
    fn test_device_config_default() {
        let DeviceConfig {
            device_type,
            device_id,
            use_fp16,
            use_tf32,
        } = DeviceConfig::default();

        assert_eq!(device_type, DeviceType::Cpu);
        assert_eq!(device_id, 0);
        assert!(!use_fp16);
        assert!(!use_tf32);
    }

    /// Chained `with_*` calls each leave their value in the result.
    #[test]
    fn test_device_config_builder() {
        let built = DeviceConfig::new()
            .with_device_id(1)
            .with_fp16(true)
            .with_tf32(true);

        assert_eq!(built.device_id, 1);
        assert!(built.use_fp16);
        assert!(built.use_tf32);
    }

    /// A default configuration opens the CPU device.
    #[test]
    fn test_cpu_device_creation() {
        let opened = DeviceConfig::new().create_device().unwrap();
        assert!(matches!(opened, Device::Cpu));
    }

    /// Selecting the best device must always succeed; which backend is
    /// returned (CPU, CUDA or Metal) depends on features and hardware, so
    /// nothing is asserted about it.
    #[test]
    fn test_get_best_device() {
        let _device = get_best_device();
    }

    /// The listing is never empty and starts with the CPU.
    #[test]
    fn test_list_devices() {
        let listed = list_devices();
        assert!(!listed.is_empty());
        assert_eq!(listed[0].device_type, DeviceType::Cpu);
    }

    /// The rendered text contains the type, the name and the total memory.
    #[test]
    fn test_device_info_display() {
        let sample = DeviceInfo {
            device_type: DeviceType::Cpu,
            device_id: 0,
            name: Some("Test CPU".to_string()),
            total_memory: Some(16 * 1024 * 1024 * 1024), // 16 GB
            available_memory: Some(8 * 1024 * 1024 * 1024), // 8 GB
        };

        let rendered = sample.to_string();
        for expected in ["CPU", "Test CPU", "16 GB"] {
            assert!(rendered.contains(expected));
        }
    }

    /// The CUDA probe must not panic, whatever it reports.
    #[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
    #[test]
    fn test_cuda_available() {
        let _ = is_cuda_available();
    }

    /// The Metal probe must not panic, whatever it reports.
    #[cfg(feature = "metal")]
    #[test]
    fn test_metal_available() {
        let _ = is_metal_available();
    }
}