Skip to main content

kapsl_hal/
device.rs

1use serde::{Deserialize, Serialize};
2use std::process::{Command, Stdio};
3use std::sync::OnceLock;
4use std::time::{Duration, Instant};
5use sys_info;
6use thiserror::Error;
7
/// Timeout applied to each external probe command by `DeviceInfo::probe` and
/// `DeviceInfo::try_probe`.
const DEFAULT_PROBE_TIMEOUT: Duration = Duration::from_millis(800);
/// Timeout passed to Metal detection (the macOS `system_profiler` query),
/// used there instead of the caller-supplied probe timeout.
const SYSTEM_PROFILER_TIMEOUT: Duration = Duration::from_secs(3);
/// Sleep between successive exit-status polls while waiting on a child process.
const COMMAND_POLL_INTERVAL: Duration = Duration::from_millis(25);

/// Process-wide cache of the result returned by `DeviceInfo::probe`.
static PROBE_CACHE: OnceLock<DeviceInfo> = OnceLock::new();
13
/// Errors that can abort a device probe.
///
/// Only the host queries made through `sys_info` are represented here; the
/// GPU detection routines in this file report failure by returning no devices.
#[derive(Debug, Error)]
pub enum DeviceProbeError {
    /// A `sys_info` query failed. Created automatically by `?` via `#[from]`.
    #[error("sys_info error: {0}")]
    SysInfo(#[from] sys_info::Error),
}
19
/// Execution backend (provider) that a device belongs to.
///
/// Built-in variants are written as a fixed lowercase name (e.g. "cuda").
/// Any string that is not one of those names is parsed into `Custom`, which
/// stores the trimmed, lowercased text and prints its payload verbatim.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum DeviceBackend {
    Cpu,
    Cuda,
    Metal,
    Rocm,
    DirectML,
    OpenCL,
    Vulkan,
    WebGpu,
    OneApi,
    Custom(String),
}

impl DeviceBackend {
    /// Maps a backend name to its variant, ignoring surrounding whitespace
    /// and ASCII case.
    ///
    /// Unrecognised names, including the empty string, become `Custom`
    /// holding the trimmed, lowercased input.
    fn parse(raw: &str) -> Self {
        let normalized = raw.trim().to_ascii_lowercase();
        match normalized.as_str() {
            "cpu" => Self::Cpu,
            "cuda" => Self::Cuda,
            "metal" => Self::Metal,
            "rocm" => Self::Rocm,
            "directml" => Self::DirectML,
            "opencl" => Self::OpenCL,
            "vulkan" => Self::Vulkan,
            "webgpu" => Self::WebGpu,
            "oneapi" => Self::OneApi,
            // Covers the empty string as well: it is stored as `Custom("")`.
            other => Self::Custom(other.to_owned()),
        }
    }
}

impl std::fmt::Display for DeviceBackend {
    /// Writes the backend name: the fixed lowercase label for built-in
    /// variants, or the stored string unchanged for `Custom`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Cpu => "cpu",
            Self::Cuda => "cuda",
            Self::Metal => "metal",
            Self::Rocm => "rocm",
            Self::DirectML => "directml",
            Self::OpenCL => "opencl",
            Self::Vulkan => "vulkan",
            Self::WebGpu => "webgpu",
            Self::OneApi => "oneapi",
            Self::Custom(name) => name.as_str(),
        };
        f.write_str(label)
    }
}
76
77impl Serialize for DeviceBackend {
78    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
79    where
80        S: serde::Serializer,
81    {
82        serializer.serialize_str(&self.to_string())
83    }
84}
85
86impl<'de> Deserialize<'de> for DeviceBackend {
87    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
88    where
89        D: serde::Deserializer<'de>,
90    {
91        let raw = String::deserialize(deserializer)?;
92        Ok(Self::parse(&raw))
93    }
94}
95
// NOTE: Keep this struct as a simple JSON-friendly record: all optional fields are `Option<...>`.
/// One compute device (the host CPU or a GPU) discovered by a probe.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Device {
    /// Index of the device within its backend (each backend numbers its own
    /// devices; the CPU entry is always 0).
    pub id: usize,
    /// Human-readable device name as reported by the probing tool, or a
    /// generic placeholder when none is available.
    pub name: String,
    /// Backend/provider this device is reached through.
    pub backend: DeviceBackend,

    /// Device memory in megabytes; 0 when the probe could not determine it.
    pub memory_mb: u64,
    /// Number of compute units. The probes in this file set it to the CPU
    /// core count for the CPU entry and to 0 for every GPU.
    pub compute_units: u32,

    /// PCI bus id (e.g. "00000000:01:00.0") when the probe reports one.
    pub pci_bus_id: Option<String>,

    /// Stable partition identifier for sub-device addressing.
    ///
    /// For NVIDIA devices this is the GPU UUID (e.g.
    /// `"GPU-xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx"`), which survives
    /// index reordering. For MIG compute instances the MIG UUID
    /// (e.g. `"MIG-GPU-xxx/0/0"`) is stored here instead.
    ///
    /// Matched by `GpuPreference::Partition` using the `mig:<id>` or
    /// `partition:<id>` selector syntax. `None` for backends that do
    /// not expose a stable UUID.
    pub partition_id: Option<String>,

    /// Driver version string (when available).
    pub driver_version: Option<String>,

    /// CUDA version string (e.g. "12.0") for CUDA-capable devices.
    pub cuda_version: Option<String>,

    /// CUDA compute capability (e.g. "8.6") for CUDA devices.
    pub compute_capability: Option<String>,

    /// GPU utilization in percent at probe time, when reported.
    pub utilization_gpu_pct: Option<u32>,
    /// GPU temperature in degrees Celsius at probe time, when reported.
    pub temperature_c: Option<u32>,

    /// Whether the device is considered FP16-capable. For CUDA devices this
    /// is derived from the compute capability; other probes set it to `true`.
    pub supports_fp16: bool,
    /// Whether the device is considered INT8-capable. For CUDA devices this
    /// is derived from the compute capability; other probes set it to `true`.
    pub supports_int8: bool,
}
135
/// Snapshot of the host and every device found by a probe.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeviceInfo {
    // CPU info
    /// Number of CPU cores reported by `sys_info::cpu_num`.
    pub cpu_cores: u32,
    /// Total system memory as reported by `sys_info::mem_info().total`
    /// (presumably KiB, since the CPU device's `memory_mb` is this value
    /// divided by 1024 — confirm against the `sys_info` docs).
    pub total_memory: u64,
    /// Operating system type, or "unknown" when it cannot be queried.
    pub os_type: String,
    /// Operating system release, or "unknown" when it cannot be queried.
    pub os_release: String,

    /// `true` when `devices` contains at least one CUDA device.
    pub has_cuda: bool,
    /// `true` when `devices` contains at least one Metal device.
    pub has_metal: bool,
    /// `true` when `devices` contains at least one ROCm device.
    pub has_rocm: bool,
    /// `true` when `devices` contains at least one DirectML device.
    pub has_directml: bool,

    // All detected devices (CPU + GPUs); the CPU entry comes first.
    pub devices: Vec<Device>,
}
152
153#[derive(Debug, Clone, PartialEq, Eq)]
154pub enum GpuPreference {
155    /// Match by provider backend and device id, e.g. "cuda:1".
156    BackendId { backend: String, id: usize },
157    /// Match by PCI bus id, e.g. "00000000:01:00.0".
158    PciBusId(String),
159    /// Match by case-insensitive substring in `Device.name`.
160    NameContains(String),
161    /// Match by stable partition identifier (GPU UUID or MIG UUID).
162    ///
163    /// Parsed from `"mig:<id>"` or `"partition:<id>"` selector strings.
164    /// The stored value is the portion after the prefix, matched
165    /// case-insensitively against `Device.partition_id`.
166    Partition(String),
167}
168
169impl GpuPreference {
170    pub fn parse(spec: &str) -> Option<Self> {
171        let trimmed = spec.trim();
172        if trimmed.is_empty() {
173            return None;
174        }
175
176        let lowered = trimmed.to_ascii_lowercase();
177
178        // Partition prefix: "mig:<id>" or "partition:<id>".
179        // Preserve original case for UUID matching (UUIDs are typically mixed-case).
180        for prefix in &["mig:", "partition:"] {
181            if lowered.starts_with(prefix) {
182                let rest = trimmed[prefix.len()..].trim();
183                if !rest.is_empty() {
184                    return Some(Self::Partition(rest.to_string()));
185                }
186            }
187        }
188
189        if let Some((backend, id)) = lowered.split_once(':') {
190            if let Ok(id) = id.trim().parse::<usize>() {
191                return Some(Self::BackendId {
192                    backend: backend.trim().to_string(),
193                    id,
194                });
195            }
196        }
197
198        // Heuristic: if it contains ':' and '.', it looks like a PCI bus id.
199        if trimmed.contains(':') && trimmed.contains('.') {
200            return Some(Self::PciBusId(trimmed.to_string()));
201        }
202
203        Some(Self::NameContains(trimmed.to_ascii_lowercase()))
204    }
205
206    /// Returns `true` if this preference matches the given device.
207    pub fn matches(&self, device: &Device) -> bool {
208        match self {
209            Self::BackendId { backend, id } => {
210                device.backend.to_string().eq_ignore_ascii_case(backend) && device.id == *id
211            }
212            Self::PciBusId(bus_id) => device
213                .pci_bus_id
214                .as_deref()
215                .is_some_and(|v| v.eq_ignore_ascii_case(bus_id)),
216            Self::NameContains(needle) => {
217                let needle_lower = needle.to_ascii_lowercase();
218                device.name.to_ascii_lowercase().contains(&needle_lower)
219            }
220            Self::Partition(partition_id) => device
221                .partition_id
222                .as_deref()
223                .is_some_and(|v| v.eq_ignore_ascii_case(partition_id)),
224        }
225    }
226}
227
228impl DeviceInfo {
229    /// Probe device information (cached).
230    pub fn probe() -> Self {
231        if let Some(cached) = PROBE_CACHE.get() {
232            return cached.clone();
233        }
234
235        let probed = Self::try_probe_with_timeout(DEFAULT_PROBE_TIMEOUT).unwrap_or_else(|_| {
236            // Best-effort fallback; keeps the old API non-panicking.
237            Self::fallback()
238        });
239
240        let _ = PROBE_CACHE.set(probed.clone());
241        probed
242    }
243
244    /// Probe device information (not cached).
245    pub fn try_probe() -> Result<Self, DeviceProbeError> {
246        Self::try_probe_with_timeout(DEFAULT_PROBE_TIMEOUT)
247    }
248
249    /// Probe device information with a timeout applied to external commands.
250    pub fn try_probe_with_timeout(timeout: Duration) -> Result<Self, DeviceProbeError> {
251        let cpu_cores = sys_info::cpu_num()?;
252        let total_memory = sys_info::mem_info()?.total;
253        let os_type = sys_info::os_type().unwrap_or_else(|_| "unknown".to_string());
254        let os_release = sys_info::os_release().unwrap_or_else(|_| "unknown".to_string());
255
256        let mut devices = Vec::new();
257        devices.push(Device {
258            id: 0,
259            name: "CPU".to_string(),
260            backend: DeviceBackend::Cpu,
261            memory_mb: total_memory / 1024,
262            compute_units: cpu_cores,
263            pci_bus_id: None,
264            partition_id: None,
265            driver_version: None,
266            cuda_version: None,
267            compute_capability: None,
268            utilization_gpu_pct: None,
269            temperature_c: None,
270            supports_fp16: true,
271            supports_int8: true,
272        });
273
274        let cuda_version = Self::detect_cuda_version(timeout);
275
276        if let Some(nvml_devices) = Self::detect_cuda_gpus_nvml(cuda_version.as_deref()) {
277            devices.extend(nvml_devices);
278        } else {
279            devices.extend(Self::detect_cuda_gpus(timeout, cuda_version.as_deref()));
280        }
281
282        devices.extend(Self::detect_rocm_gpus(timeout, &os_release));
283        devices.extend(Self::detect_metal(SYSTEM_PROFILER_TIMEOUT, &os_release));
284        devices.extend(Self::detect_directml(timeout, &os_release));
285        devices.extend(Self::detect_oneapi(timeout));
286        devices.extend(Self::detect_webgpu());
287
288        let (has_cuda, has_metal, has_rocm, has_directml) = Self::provider_flags(&devices);
289
290        Ok(Self {
291            cpu_cores,
292            total_memory,
293            os_type,
294            os_release,
295            has_cuda,
296            has_metal,
297            has_rocm,
298            has_directml,
299            devices,
300        })
301    }
302
303    fn fallback() -> Self {
304        Self {
305            cpu_cores: 1,
306            total_memory: 0,
307            os_type: "unknown".to_string(),
308            os_release: "unknown".to_string(),
309            has_cuda: false,
310            has_metal: false,
311            has_rocm: false,
312            has_directml: false,
313            devices: vec![Device {
314                id: 0,
315                name: "CPU".to_string(),
316                backend: DeviceBackend::Cpu,
317                memory_mb: 0,
318                compute_units: 1,
319                pci_bus_id: None,
320                partition_id: None,
321                driver_version: None,
322                cuda_version: None,
323                compute_capability: None,
324                utilization_gpu_pct: None,
325                temperature_c: None,
326                supports_fp16: true,
327                supports_int8: true,
328            }],
329        }
330    }
331
332    fn provider_flags(devices: &[Device]) -> (bool, bool, bool, bool) {
333        let has_cuda = devices
334            .iter()
335            .any(|d| matches!(d.backend, DeviceBackend::Cuda));
336        let has_metal = devices
337            .iter()
338            .any(|d| matches!(d.backend, DeviceBackend::Metal));
339        let has_rocm = devices
340            .iter()
341            .any(|d| matches!(d.backend, DeviceBackend::Rocm));
342        let has_directml = devices
343            .iter()
344            .any(|d| matches!(d.backend, DeviceBackend::DirectML));
345        (has_cuda, has_metal, has_rocm, has_directml)
346    }
347
348    fn run_command_with_timeout(
349        program: &str,
350        args: &[&str],
351        timeout: Duration,
352    ) -> Option<Vec<u8>> {
353        let mut child = Command::new(program)
354            .args(args)
355            .stdout(Stdio::piped())
356            .stderr(Stdio::piped())
357            .spawn()
358            .ok()?;
359
360        let start = Instant::now();
361        loop {
362            if start.elapsed() >= timeout {
363                let _ = child.kill();
364                let _ = child.wait();
365                return None;
366            }
367
368            match child.try_wait().ok()? {
369                Some(_) => break,
370                None => std::thread::sleep(COMMAND_POLL_INTERVAL),
371            }
372        }
373
374        let out = child.wait_with_output().ok()?;
375        if !out.status.success() {
376            return None;
377        }
378        Some(out.stdout)
379    }
380
381    fn parse_cuda_version_from_smi_summary(stdout: &str) -> Option<String> {
382        // Example header line:
383        // | NVIDIA-SMI 535.54.03 Driver Version: 535.54.03 CUDA Version: 12.2 |
384        let needle = "CUDA Version:";
385        let idx = stdout.find(needle)?;
386        let rest = &stdout[idx + needle.len()..];
387        let version = rest
388            .trim_start()
389            .split(|c: char| c.is_whitespace() || c == '|')
390            .next()?
391            .trim();
392        if version.is_empty() {
393            None
394        } else {
395            Some(version.to_string())
396        }
397    }
398
399    fn detect_cuda_version(timeout: Duration) -> Option<String> {
400        let stdout = Self::run_command_with_timeout("nvidia-smi", &[], timeout)?;
401        let text = String::from_utf8_lossy(&stdout);
402        Self::parse_cuda_version_from_smi_summary(&text)
403    }
404
405    fn parse_compute_capability(value: &str) -> Option<(u32, u32)> {
406        let trimmed = value.trim();
407        if trimmed.is_empty() || trimmed.eq_ignore_ascii_case("n/a") {
408            return None;
409        }
410        let (major, minor) = trimmed.split_once('.')?;
411        Some((major.parse().ok()?, minor.parse().ok()?))
412    }
413
414    fn detect_cuda_gpus(timeout: Duration, cuda_version: Option<&str>) -> Vec<Device> {
415        let mut devices = Vec::new();
416
417        let query = "index,name,memory.total,utilization.gpu,temperature.gpu,pci.bus_id,driver_version,compute_cap,uuid";
418        let args = ["--query-gpu", query, "--format=csv,noheader,nounits"];
419
420        let stdout = match Self::run_command_with_timeout("nvidia-smi", &args, timeout) {
421            Some(s) => s,
422            None => return devices,
423        };
424
425        let text = String::from_utf8_lossy(&stdout);
426        for line in text.lines() {
427            let parts: Vec<&str> = line.split(',').map(|s| s.trim()).collect();
428            if parts.len() < 7 {
429                continue;
430            }
431
432            let id = parts[0].parse::<usize>().unwrap_or(0);
433            let name = parts.get(1).copied().unwrap_or("NVIDIA GPU").to_string();
434            let memory_mb = parts
435                .get(2)
436                .and_then(|v| v.parse::<u64>().ok())
437                .unwrap_or(0);
438            let utilization_gpu_pct = parts.get(3).and_then(|v| v.parse::<u32>().ok());
439            let temperature_c = parts.get(4).and_then(|v| v.parse::<u32>().ok());
440
441            let pci_bus_id = parts
442                .get(5)
443                .map(|v| v.trim())
444                .filter(|v| !v.is_empty() && !v.eq_ignore_ascii_case("n/a"))
445                .map(|v| v.to_string());
446
447            let driver_version = parts
448                .get(6)
449                .map(|v| v.trim())
450                .filter(|v| !v.is_empty() && !v.eq_ignore_ascii_case("n/a"))
451                .map(|v| v.to_string());
452
453            let compute_capability = parts
454                .get(7)
455                .map(|v| v.trim())
456                .filter(|v| !v.is_empty() && !v.eq_ignore_ascii_case("n/a"))
457                .map(|v| v.to_string());
458
459            // parts[8] = uuid (e.g. "GPU-xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx").
460            // Used as the stable partition_id so placement can survive index reordering.
461            let partition_id = parts
462                .get(8)
463                .map(|v| v.trim())
464                .filter(|v| !v.is_empty() && !v.eq_ignore_ascii_case("n/a"))
465                .map(|v| v.to_string());
466
467            let (supports_fp16, supports_int8) = match compute_capability
468                .as_deref()
469                .and_then(Self::parse_compute_capability)
470            {
471                Some((major, _minor)) => (major >= 5, major >= 6),
472                None => (true, true),
473            };
474
475            devices.push(Device {
476                id,
477                name,
478                backend: DeviceBackend::Cuda,
479                memory_mb,
480                compute_units: 0,
481                pci_bus_id,
482                partition_id,
483                driver_version,
484                cuda_version: cuda_version.map(|s| s.to_string()),
485                compute_capability,
486                utilization_gpu_pct,
487                temperature_c,
488                supports_fp16,
489                supports_int8,
490            });
491        }
492
493        devices
494    }
495
496    fn detect_rocm_gpus(timeout: Duration, os_release: &str) -> Vec<Device> {
497        let mut devices = Vec::new();
498
499        let stdout = match Self::run_command_with_timeout("rocm-smi", &["-i"], timeout) {
500            Some(s) => s,
501            None => return devices,
502        };
503
504        let text = String::from_utf8_lossy(&stdout);
505
506        // Prefer bracket form: GPU[0]
507        let mut ids = Vec::new();
508        for line in text.lines() {
509            let line = line.trim();
510            if let Some(start) = line.find("GPU[") {
511                let rest = &line[start + 4..];
512                if let Some(end) = rest.find(']') {
513                    if let Ok(id) = rest[..end].parse::<usize>() {
514                        if !ids.contains(&id) {
515                            ids.push(id);
516                        }
517                    }
518                }
519            }
520        }
521
522        if ids.is_empty() {
523            // Fallback: table format where first token is GPU index.
524            for line in text.lines() {
525                let line = line.trim_start();
526                if line.is_empty() {
527                    continue;
528                }
529                let first = line.split_whitespace().next().unwrap_or("");
530                if let Ok(id) = first.parse::<usize>() {
531                    if !ids.contains(&id) {
532                        ids.push(id);
533                    }
534                }
535            }
536        }
537
538        ids.sort_unstable();
539        for id in ids {
540            devices.push(Device {
541                id,
542                name: format!("AMD ROCm GPU {id}"),
543                backend: DeviceBackend::Rocm,
544                memory_mb: 0,
545                compute_units: 0,
546                pci_bus_id: None,
547                partition_id: None,
548                driver_version: Some(os_release.to_string()),
549                cuda_version: None,
550                compute_capability: None,
551                utilization_gpu_pct: None,
552                temperature_c: None,
553                supports_fp16: true,
554                supports_int8: true,
555            });
556        }
557
558        devices
559    }
560
561    fn parse_memory_mb(value: &str) -> Option<u64> {
562        let lowered = value.trim().to_ascii_lowercase();
563        if lowered.is_empty() {
564            return None;
565        }
566
567        let mut number = String::new();
568        for ch in lowered.chars() {
569            if ch.is_ascii_digit() || ch == '.' {
570                number.push(ch);
571            } else if !number.is_empty() {
572                break;
573            }
574        }
575
576        let num: f64 = number.parse().ok()?;
577        if lowered.contains("gb") {
578            Some((num * 1024.0) as u64)
579        } else if lowered.contains("mb") {
580            Some(num as u64)
581        } else if lowered.contains("kb") {
582            Some((num / 1024.0) as u64)
583        } else {
584            None
585        }
586    }
587
588    fn detect_metal(timeout: Duration, os_release: &str) -> Vec<Device> {
589        #[cfg(target_os = "macos")]
590        {
591            let mut devs = Vec::new();
592            let stdout = match Self::run_command_with_timeout(
593                "system_profiler",
594                &["SPDisplaysDataType", "-json"],
595                timeout,
596            ) {
597                Some(s) => s,
598                None => return devs,
599            };
600
601            let value: serde_json::Value = match serde_json::from_slice(&stdout) {
602                Ok(v) => v,
603                Err(_) => return devs,
604            };
605
606            let displays = match value.get("SPDisplaysDataType").and_then(|v| v.as_array()) {
607                Some(v) => v,
608                None => return devs,
609            };
610
611            for (idx, item) in displays.iter().enumerate() {
612                let name = item
613                    .get("spdisplays_chipset_model")
614                    .and_then(|v| v.as_str())
615                    .or_else(|| item.get("sppci_model").and_then(|v| v.as_str()))
616                    .or_else(|| item.get("_name").and_then(|v| v.as_str()))
617                    .unwrap_or("Apple Metal GPU")
618                    .to_string();
619
620                let vram_text = item
621                    .get("spdisplays_vram")
622                    .and_then(|v| v.as_str())
623                    .or_else(|| item.get("spdisplays_vram_shared").and_then(|v| v.as_str()))
624                    .unwrap_or("");
625
626                let memory_mb = Self::parse_memory_mb(vram_text).unwrap_or(0);
627
628                devs.push(Device {
629                    id: idx,
630                    name,
631                    backend: DeviceBackend::Metal,
632                    memory_mb,
633                    compute_units: 0,
634                    pci_bus_id: None,
635                    partition_id: None,
636                    driver_version: Some(os_release.to_string()),
637                    cuda_version: None,
638                    compute_capability: None,
639                    utilization_gpu_pct: None,
640                    temperature_c: None,
641                    supports_fp16: true,
642                    supports_int8: true,
643                });
644            }
645
646            devs
647        }
648
649        #[cfg(not(target_os = "macos"))]
650        {
651            let _ = (timeout, os_release);
652            Vec::new()
653        }
654    }
655
656    fn detect_directml(_timeout: Duration, os_release: &str) -> Vec<Device> {
657        #[cfg(target_os = "windows")]
658        {
659            // Best-effort: DirectML runs on DX12 adapters; detailed enumeration would require
660            // Windows APIs. Provide a placeholder device for feature-gating higher layers.
661            vec![Device {
662                id: 0,
663                name: "DirectML GPU".into(),
664                backend: DeviceBackend::DirectML,
665                memory_mb: 0,
666                compute_units: 0,
667                pci_bus_id: None,
668                partition_id: None,
669                driver_version: Some(os_release.to_string()),
670                cuda_version: None,
671                compute_capability: None,
672                utilization_gpu_pct: None,
673                temperature_c: None,
674                supports_fp16: true,
675                supports_int8: true,
676            }]
677        }
678
679        #[cfg(not(target_os = "windows"))]
680        {
681            let _ = os_release;
682            Vec::new()
683        }
684    }
685
686    fn detect_oneapi(_timeout: Duration) -> Vec<Device> {
687        // OneAPI/Level-Zero enumeration is intentionally stubbed here to keep `kapsl-hal`
688        // dependency-light. Higher layers can treat the backend as available when they
689        // can actually create a OneAPI engine.
690        Vec::new()
691    }
692
693    fn detect_webgpu() -> Vec<Device> {
694        // WebGPU is only meaningful for wasm/browser builds.
695        Vec::new()
696    }
697
698    #[cfg(feature = "nvml")]
699    fn detect_cuda_gpus_nvml(cuda_version: Option<&str>) -> Option<Vec<Device>> {
700        use nvml_wrapper::Nvml;
701
702        let nvml = Nvml::init().ok()?;
703        let driver_version = nvml.sys_driver_version().ok();
704        let count = nvml.device_count().ok()?;
705
706        let mut devices = Vec::with_capacity(count as usize);
707        for index in 0..count {
708            let dev = nvml.device_by_index(index).ok()?;
709            let name = dev.name().ok().unwrap_or_else(|| "NVIDIA GPU".to_string());
710            let memory_mb = dev
711                .memory_info()
712                .ok()
713                .map(|m| m.total / (1024 * 1024))
714                .unwrap_or(0);
715            let pci_bus_id = dev
716                .pci_info()
717                .ok()
718                .map(|p| p.bus_id)
719                .filter(|s| !s.trim().is_empty());
720            let utilization_gpu_pct = dev.utilization_rates().ok().map(|u| u.gpu);
721            let temperature_c = dev
722                .temperature(nvml_wrapper::enum_wrappers::device::TemperatureSensor::Gpu)
723                .ok();
724            let cc = dev
725                .cuda_compute_capability()
726                .ok()
727                .map(|(maj, min)| format!("{maj}.{min}"));
728
729            // UUID is a stable per-device identifier (e.g. "GPU-xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx")
730            // that survives index reordering. For MIG compute instances NVML would return a
731            // MIG-scoped UUID here instead. Used as partition_id for selector matching.
732            let partition_id = dev.uuid().ok().filter(|s| !s.trim().is_empty());
733
734            let (supports_fp16, supports_int8) =
735                match cc.as_deref().and_then(Self::parse_compute_capability) {
736                    Some((major, _minor)) => (major >= 5, major >= 6),
737                    None => (true, true),
738                };
739
740            devices.push(Device {
741                id: index as usize,
742                name,
743                backend: DeviceBackend::Cuda,
744                memory_mb,
745                compute_units: 0,
746                pci_bus_id,
747                partition_id,
748                driver_version: driver_version.clone(),
749                cuda_version: cuda_version.map(|s| s.to_string()),
750                compute_capability: cc,
751                utilization_gpu_pct,
752                temperature_c,
753                supports_fp16,
754                supports_int8,
755            });
756        }
757
758        Some(devices)
759    }
760
761    #[cfg(not(feature = "nvml"))]
762    fn detect_cuda_gpus_nvml(_cuda_version: Option<&str>) -> Option<Vec<Device>> {
763        None
764    }
765
766    /// Select the "best" GPU using a simple heuristic.
767    ///
768    /// Primary key: `memory_mb`.
769    /// Tie-breaker: `compute_capability` when present.
770    pub fn best_gpu(&self) -> Option<&Device> {
771        self.devices
772            .iter()
773            .filter(|d| !matches!(d.backend, DeviceBackend::Cpu))
774            .max_by(|a, b| {
775                let by_mem = a.memory_mb.cmp(&b.memory_mb);
776                if by_mem != std::cmp::Ordering::Equal {
777                    return by_mem;
778                }
779
780                let a_cc = a
781                    .compute_capability
782                    .as_deref()
783                    .and_then(Self::parse_compute_capability)
784                    .unwrap_or((0, 0));
785                let b_cc = b
786                    .compute_capability
787                    .as_deref()
788                    .and_then(Self::parse_compute_capability)
789                    .unwrap_or((0, 0));
790                a_cc.cmp(&b_cc)
791            })
792    }
793
794    pub fn best_gpu_with_preference(&self, preference: &GpuPreference) -> Option<&Device> {
795        self.devices
796            .iter()
797            .find(|d| !matches!(d.backend, DeviceBackend::Cpu) && preference.matches(d))
798    }
799
800    pub fn cuda_devices(&self) -> Vec<&Device> {
801        self.devices
802            .iter()
803            .filter(|d| matches!(d.backend, DeviceBackend::Cuda))
804            .collect()
805    }
806
807    /// Get the best available execution provider.
808    pub fn get_best_provider(&self) -> String {
809        if self.has_cuda {
810            "cuda".to_string()
811        } else if self.has_metal {
812            "metal".to_string()
813        } else if self.has_rocm {
814            "rocm".to_string()
815        } else if self
816            .devices
817            .iter()
818            .any(|d| matches!(d.backend, DeviceBackend::OneApi))
819        {
820            "oneapi".to_string()
821        } else if self.has_directml {
822            "directml".to_string()
823        } else {
824            "cpu".to_string()
825        }
826    }
827
828    /// Check if a specific provider is available.
829    pub fn has_provider(&self, provider: &str) -> bool {
830        let key = provider.trim().to_ascii_lowercase();
831        match key.as_str() {
832            "cuda" => self.has_cuda,
833            "metal" | "coreml" => self.has_metal,
834            "rocm" => self.has_rocm,
835            "directml" => self.has_directml,
836            "oneapi" => self
837                .devices
838                .iter()
839                .any(|d| matches!(d.backend, DeviceBackend::OneApi)),
840            "webgpu" => self
841                .devices
842                .iter()
843                .any(|d| matches!(d.backend, DeviceBackend::WebGpu)),
844            "cpu" => true,
845            other => self
846                .devices
847                .iter()
848                .any(|d| d.backend.to_string().eq_ignore_ascii_case(other)),
849        }
850    }
851}