1#[cfg(any(feature = "cuda", feature = "metal"))]
30use crate::error::CoreError;
31use crate::error::CoreResult;
32use candle_core::Device;
33use serde::{Deserialize, Serialize};
34use std::fmt;
35
36#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
38pub enum DeviceType {
39 Cpu,
41 #[cfg(feature = "cuda")]
43 Cuda,
44 #[cfg(feature = "metal")]
46 Metal,
47}
48
49impl fmt::Display for DeviceType {
50 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
51 match self {
52 DeviceType::Cpu => write!(f, "CPU"),
53 #[cfg(feature = "cuda")]
54 DeviceType::Cuda => write!(f, "CUDA"),
55 #[cfg(feature = "metal")]
56 DeviceType::Metal => write!(f, "Metal"),
57 }
58 }
59}
60
61#[derive(Debug, Clone, Serialize, Deserialize)]
63pub struct DeviceConfig {
64 pub device_type: DeviceType,
66 pub device_id: usize,
68 pub use_fp16: bool,
70 pub use_tf32: bool,
72}
73
74impl Default for DeviceConfig {
75 fn default() -> Self {
76 Self {
77 device_type: DeviceType::Cpu,
78 device_id: 0,
79 use_fp16: false,
80 use_tf32: false,
81 }
82 }
83}
84
85impl DeviceConfig {
86 pub fn new() -> Self {
88 Self::default()
89 }
90
91 pub fn with_device_type(mut self, device_type: DeviceType) -> Self {
93 self.device_type = device_type;
94 self
95 }
96
97 pub fn with_device_id(mut self, device_id: usize) -> Self {
99 self.device_id = device_id;
100 self
101 }
102
103 pub fn with_fp16(mut self, enabled: bool) -> Self {
105 self.use_fp16 = enabled;
106 self
107 }
108
109 pub fn with_tf32(mut self, enabled: bool) -> Self {
111 self.use_tf32 = enabled;
112 self
113 }
114
115 pub fn create_device(&self) -> CoreResult<Device> {
117 match self.device_type {
118 DeviceType::Cpu => Ok(Device::Cpu),
119
120 #[cfg(feature = "cuda")]
121 DeviceType::Cuda => {
122 #[cfg(any(target_os = "linux", target_os = "windows"))]
123 {
124 Device::new_cuda(self.device_id).map_err(|e| {
125 CoreError::DeviceError(format!(
126 "Failed to create CUDA device {}: {}",
127 self.device_id, e
128 ))
129 })
130 }
131 #[cfg(not(any(target_os = "linux", target_os = "windows")))]
132 {
133 Err(CoreError::DeviceError(
134 "CUDA is not supported on this platform (requires Linux or Windows)"
135 .to_string(),
136 ))
137 }
138 }
139
140 #[cfg(feature = "metal")]
141 DeviceType::Metal => Device::new_metal(self.device_id).map_err(|e| {
142 CoreError::DeviceError(format!(
143 "Failed to create Metal device {}: {}",
144 self.device_id, e
145 ))
146 }),
147 }
148 }
149}
150
/// Reports whether CUDA ordinal 0 can be opened.
///
/// Always `false` unless this crate is built with the `cuda` feature for a
/// Linux or Windows target; otherwise it attempts to create the device.
pub fn is_cuda_available() -> bool {
    #[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
    fn probe() -> bool {
        Device::new_cuda(0).is_ok()
    }

    #[cfg(not(all(feature = "cuda", any(target_os = "linux", target_os = "windows"))))]
    fn probe() -> bool {
        false
    }

    probe()
}
162
/// Reports whether Metal ordinal 0 can be opened.
///
/// Always `false` unless this crate is built with the `metal` feature;
/// otherwise it attempts to create the device.
pub fn is_metal_available() -> bool {
    #[cfg(feature = "metal")]
    fn probe() -> bool {
        Device::new_metal(0).is_ok()
    }

    #[cfg(not(feature = "metal"))]
    fn probe() -> bool {
        false
    }

    probe()
}
174
175pub fn get_best_device() -> Device {
177 #[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
178 {
179 if let Ok(device) = Device::new_cuda(0) {
180 tracing::info!("Using CUDA device 0");
181 return device;
182 }
183 }
184
185 #[cfg(feature = "metal")]
186 {
187 if let Ok(device) = Device::new_metal(0) {
188 tracing::info!("Using Metal device 0");
189 return device;
190 }
191 }
192
193 tracing::info!("Using CPU device");
194 Device::Cpu
195}
196
/// Lists the CUDA ordinals that can be opened.
///
/// Probes ordinals 0, 1, 2, ... (at most 16) and stops at the first one that
/// fails, so any gap in the numbering ends the scan. Each probed device is
/// dropped immediately after the check.
#[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
pub fn get_cuda_devices() -> Vec<usize> {
    (0..16)
        .take_while(|&ordinal| Device::new_cuda(ordinal).is_ok())
        .collect()
}
211
/// Lists the Metal ordinals that can be opened.
///
/// Only ordinal 0 is probed, so the result is either `[0]` or empty.
#[cfg(feature = "metal")]
pub fn get_metal_devices() -> Vec<usize> {
    std::iter::once(0)
        .filter(|&ordinal| Device::new_metal(ordinal).is_ok())
        .collect()
}
223
224#[derive(Debug, Clone)]
226pub struct DeviceInfo {
227 pub device_type: DeviceType,
229 pub device_id: usize,
231 pub name: Option<String>,
233 pub total_memory: Option<u64>,
235 pub available_memory: Option<u64>,
237}
238
239impl fmt::Display for DeviceInfo {
240 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
241 write!(f, "{} Device {}", self.device_type, self.device_id)?;
242 if let Some(name) = &self.name {
243 write!(f, " ({})", name)?;
244 }
245 if let Some(total) = self.total_memory {
246 write!(f, " - Total Memory: {} GB", total / (1024 * 1024 * 1024))?;
247 }
248 if let Some(available) = self.available_memory {
249 write!(f, " - Available: {} GB", available / (1024 * 1024 * 1024))?;
250 }
251 Ok(())
252 }
253}
254
255pub fn get_device_info(device: &Device) -> DeviceInfo {
257 match device {
258 Device::Cpu => DeviceInfo {
259 device_type: DeviceType::Cpu,
260 device_id: 0,
261 name: Some("CPU".to_string()),
262 total_memory: None,
263 available_memory: None,
264 },
265
266 #[cfg(feature = "cuda")]
267 Device::Cuda(_cuda_device) => {
268 DeviceInfo {
272 device_type: DeviceType::Cuda,
273 device_id: 0,
274 name: None, total_memory: None, available_memory: None, }
278 }
279
280 #[cfg(feature = "metal")]
281 Device::Metal(_metal_device) => {
282 DeviceInfo {
283 device_type: DeviceType::Metal,
284 device_id: 0, name: None, total_memory: None, available_memory: None, }
289 }
290
291 #[allow(unreachable_patterns)]
294 _ => DeviceInfo {
295 device_type: DeviceType::Cpu,
296 device_id: 0,
297 name: Some("Unknown".to_string()),
298 total_memory: None,
299 available_memory: None,
300 },
301 }
302}
303
304pub fn list_devices() -> Vec<DeviceInfo> {
306 #[allow(unused_mut)]
307 let mut result = vec![DeviceInfo {
308 device_type: DeviceType::Cpu,
309 device_id: 0,
310 name: Some("CPU".to_string()),
311 total_memory: None,
312 available_memory: None,
313 }];
314
315 #[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
316 {
317 for id in get_cuda_devices() {
318 if let Ok(device) = Device::new_cuda(id) {
319 result.push(get_device_info(&device));
320 }
321 }
322 }
323
324 #[cfg(feature = "metal")]
325 {
326 for id in get_metal_devices() {
327 if let Ok(device) = Device::new_metal(id) {
328 result.push(get_device_info(&device));
329 }
330 }
331 }
332
333 result
334}
335
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_device_config_default() {
        let config = DeviceConfig::default();
        assert_eq!(config.device_type, DeviceType::Cpu);
        assert_eq!(config.device_id, 0);
        assert!(!config.use_fp16);
        assert!(!config.use_tf32);
    }

    #[test]
    fn test_device_config_builder() {
        let config = DeviceConfig::new()
            .with_device_type(DeviceType::Cpu)
            .with_device_id(1)
            .with_fp16(true)
            .with_tf32(true);

        assert_eq!(config.device_type, DeviceType::Cpu);
        assert_eq!(config.device_id, 1);
        assert!(config.use_fp16);
        assert!(config.use_tf32);
    }

    #[test]
    fn test_cpu_device_creation() {
        let config = DeviceConfig::new();
        let device = config.create_device().unwrap();
        assert!(matches!(device, Device::Cpu));
    }

    #[test]
    fn test_device_type_display() {
        assert_eq!(DeviceType::Cpu.to_string(), "CPU");
    }

    #[test]
    fn test_get_best_device() {
        let device = get_best_device();
        // Without any accelerator feature compiled in, the CPU is the only
        // possible choice; with features on, the result depends on the host.
        if !cfg!(any(feature = "cuda", feature = "metal")) {
            assert!(matches!(device, Device::Cpu));
        }
    }

    #[test]
    fn test_get_device_info_cpu() {
        let info = get_device_info(&Device::Cpu);
        assert_eq!(info.device_type, DeviceType::Cpu);
        assert_eq!(info.device_id, 0);
        assert_eq!(info.name.as_deref(), Some("CPU"));
        assert!(info.total_memory.is_none());
        assert!(info.available_memory.is_none());
    }

    #[test]
    fn test_list_devices() {
        let devices = list_devices();
        assert!(!devices.is_empty());
        // The CPU is always listed first.
        assert_eq!(devices[0].device_type, DeviceType::Cpu);
        assert_eq!(devices[0].name.as_deref(), Some("CPU"));
    }

    #[test]
    fn test_device_info_display() {
        let info = DeviceInfo {
            device_type: DeviceType::Cpu,
            device_id: 0,
            name: Some("Test CPU".to_string()),
            total_memory: Some(16 * 1024 * 1024 * 1024),
            available_memory: Some(8 * 1024 * 1024 * 1024),
        };
        let display = format!("{}", info);
        assert!(display.contains("CPU"));
        assert!(display.contains("Test CPU"));
        assert!(display.contains("16 GB"));
        assert!(display.contains("Available: 8 GB"));
    }

    #[test]
    fn test_device_info_display_without_optional_fields() {
        let info = DeviceInfo {
            device_type: DeviceType::Cpu,
            device_id: 3,
            name: None,
            total_memory: None,
            available_memory: None,
        };
        assert_eq!(info.to_string(), "CPU Device 3");
    }

    #[test]
    fn test_accelerators_unavailable_without_features() {
        if !cfg!(feature = "cuda") {
            assert!(!is_cuda_available());
        }
        if !cfg!(feature = "metal") {
            assert!(!is_metal_available());
        }
    }

    #[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
    #[test]
    fn test_cuda_available() {
        // Hardware-dependent: only checks that probing does not panic.
        let _ = is_cuda_available();
    }

    #[cfg(feature = "metal")]
    #[test]
    fn test_metal_available() {
        // Hardware-dependent: only checks that probing does not panic.
        let _ = is_metal_available();
    }
}
412}