1use serde::{Deserialize, Serialize};
2use std::process::{Command, Stdio};
3use std::sync::OnceLock;
4use std::time::{Duration, Instant};
5use sys_info;
6use thiserror::Error;
7
8const DEFAULT_PROBE_TIMEOUT: Duration = Duration::from_millis(800);
9const SYSTEM_PROFILER_TIMEOUT: Duration = Duration::from_secs(3);
10const COMMAND_POLL_INTERVAL: Duration = Duration::from_millis(25);
11
12static PROBE_CACHE: OnceLock<DeviceInfo> = OnceLock::new();
13
14#[derive(Debug, Error)]
15pub enum DeviceProbeError {
16 #[error("sys_info error: {0}")]
17 SysInfo(#[from] sys_info::Error),
18}
19
/// Compute backend through which a `Device` is driven.
///
/// Serialised as its lower-case label (the `Display` output). Labels that are
/// not recognised when parsing are kept in [`DeviceBackend::Custom`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum DeviceBackend {
    /// Host CPU.
    Cpu,
    /// NVIDIA CUDA.
    Cuda,
    /// Apple Metal.
    Metal,
    /// AMD ROCm.
    Rocm,
    /// Microsoft DirectML.
    DirectML,
    /// OpenCL.
    OpenCL,
    /// Vulkan compute.
    Vulkan,
    /// WebGPU.
    WebGpu,
    /// Intel oneAPI.
    OneApi,
    /// Any other backend, identified by its label.
    Custom(String),
}
37
38impl DeviceBackend {
39 fn parse(raw: &str) -> Self {
40 let trimmed = raw.trim();
41 if trimmed.is_empty() {
42 return Self::Custom(String::new());
43 }
44
45 match trimmed.to_ascii_lowercase().as_str() {
46 "cpu" => Self::Cpu,
47 "cuda" => Self::Cuda,
48 "metal" => Self::Metal,
49 "rocm" => Self::Rocm,
50 "directml" => Self::DirectML,
51 "opencl" => Self::OpenCL,
52 "vulkan" => Self::Vulkan,
53 "webgpu" => Self::WebGpu,
54 "oneapi" => Self::OneApi,
55 other => Self::Custom(other.to_string()),
56 }
57 }
58}
59
60impl std::fmt::Display for DeviceBackend {
61 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62 match self {
63 DeviceBackend::Cpu => write!(f, "cpu"),
64 DeviceBackend::Cuda => write!(f, "cuda"),
65 DeviceBackend::Metal => write!(f, "metal"),
66 DeviceBackend::Rocm => write!(f, "rocm"),
67 DeviceBackend::DirectML => write!(f, "directml"),
68 DeviceBackend::OpenCL => write!(f, "opencl"),
69 DeviceBackend::Vulkan => write!(f, "vulkan"),
70 DeviceBackend::WebGpu => write!(f, "webgpu"),
71 DeviceBackend::OneApi => write!(f, "oneapi"),
72 DeviceBackend::Custom(s) => write!(f, "{s}"),
73 }
74 }
75}
76
77impl Serialize for DeviceBackend {
78 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
79 where
80 S: serde::Serializer,
81 {
82 serializer.serialize_str(&self.to_string())
83 }
84}
85
86impl<'de> Deserialize<'de> for DeviceBackend {
87 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
88 where
89 D: serde::Deserializer<'de>,
90 {
91 let raw = String::deserialize(deserializer)?;
92 Ok(Self::parse(&raw))
93 }
94}
95
96#[derive(Debug, Clone, Serialize, Deserialize)]
98pub struct Device {
99 pub id: usize,
100 pub name: String,
101 pub backend: DeviceBackend,
102
103 pub memory_mb: u64,
104 pub compute_units: u32,
105
106 pub pci_bus_id: Option<String>,
107
108 pub partition_id: Option<String>,
119
120 pub driver_version: Option<String>,
122
123 pub cuda_version: Option<String>,
125
126 pub compute_capability: Option<String>,
128
129 pub utilization_gpu_pct: Option<u32>,
130 pub temperature_c: Option<u32>,
131
132 pub supports_fp16: bool,
133 pub supports_int8: bool,
134}
135
136#[derive(Debug, Clone, Serialize, Deserialize)]
137pub struct DeviceInfo {
138 pub cpu_cores: u32,
140 pub total_memory: u64,
141 pub os_type: String,
142 pub os_release: String,
143
144 pub has_cuda: bool,
145 pub has_metal: bool,
146 pub has_rocm: bool,
147 pub has_directml: bool,
148
149 pub devices: Vec<Device>,
151}
152
/// A user-supplied rule for selecting a GPU, usually built by `GpuPreference::parse`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GpuPreference {
    /// Backend label plus per-backend device index, e.g. `cuda:1`.
    BackendId { backend: String, id: usize },
    /// PCI bus id in `domain:bus:device.function` form, e.g. `0000:01:00.0`.
    PciBusId(String),
    /// Case-insensitive substring of the device name.
    NameContains(String),
    /// Partition id, compared case-insensitively with `Device::partition_id`.
    Partition(String),
}
168
169impl GpuPreference {
170 pub fn parse(spec: &str) -> Option<Self> {
171 let trimmed = spec.trim();
172 if trimmed.is_empty() {
173 return None;
174 }
175
176 let lowered = trimmed.to_ascii_lowercase();
177
178 for prefix in &["mig:", "partition:"] {
181 if lowered.starts_with(prefix) {
182 let rest = trimmed[prefix.len()..].trim();
183 if !rest.is_empty() {
184 return Some(Self::Partition(rest.to_string()));
185 }
186 }
187 }
188
189 if let Some((backend, id)) = lowered.split_once(':') {
190 if let Ok(id) = id.trim().parse::<usize>() {
191 return Some(Self::BackendId {
192 backend: backend.trim().to_string(),
193 id,
194 });
195 }
196 }
197
198 if trimmed.contains(':') && trimmed.contains('.') {
200 return Some(Self::PciBusId(trimmed.to_string()));
201 }
202
203 Some(Self::NameContains(trimmed.to_ascii_lowercase()))
204 }
205
206 pub fn matches(&self, device: &Device) -> bool {
208 match self {
209 Self::BackendId { backend, id } => {
210 device.backend.to_string().eq_ignore_ascii_case(backend) && device.id == *id
211 }
212 Self::PciBusId(bus_id) => device
213 .pci_bus_id
214 .as_deref()
215 .is_some_and(|v| v.eq_ignore_ascii_case(bus_id)),
216 Self::NameContains(needle) => {
217 let needle_lower = needle.to_ascii_lowercase();
218 device.name.to_ascii_lowercase().contains(&needle_lower)
219 }
220 Self::Partition(partition_id) => device
221 .partition_id
222 .as_deref()
223 .is_some_and(|v| v.eq_ignore_ascii_case(partition_id)),
224 }
225 }
226}
227
228impl DeviceInfo {
229 pub fn probe() -> Self {
231 if let Some(cached) = PROBE_CACHE.get() {
232 return cached.clone();
233 }
234
235 let probed = Self::try_probe_with_timeout(DEFAULT_PROBE_TIMEOUT).unwrap_or_else(|_| {
236 Self::fallback()
238 });
239
240 let _ = PROBE_CACHE.set(probed.clone());
241 probed
242 }
243
244 pub fn try_probe() -> Result<Self, DeviceProbeError> {
246 Self::try_probe_with_timeout(DEFAULT_PROBE_TIMEOUT)
247 }
248
249 pub fn try_probe_with_timeout(timeout: Duration) -> Result<Self, DeviceProbeError> {
251 let cpu_cores = sys_info::cpu_num()?;
252 let total_memory = sys_info::mem_info()?.total;
253 let os_type = sys_info::os_type().unwrap_or_else(|_| "unknown".to_string());
254 let os_release = sys_info::os_release().unwrap_or_else(|_| "unknown".to_string());
255
256 let mut devices = Vec::new();
257 devices.push(Device {
258 id: 0,
259 name: "CPU".to_string(),
260 backend: DeviceBackend::Cpu,
261 memory_mb: total_memory / 1024,
262 compute_units: cpu_cores,
263 pci_bus_id: None,
264 partition_id: None,
265 driver_version: None,
266 cuda_version: None,
267 compute_capability: None,
268 utilization_gpu_pct: None,
269 temperature_c: None,
270 supports_fp16: true,
271 supports_int8: true,
272 });
273
274 let cuda_version = Self::detect_cuda_version(timeout);
275
276 if let Some(nvml_devices) = Self::detect_cuda_gpus_nvml(cuda_version.as_deref()) {
277 devices.extend(nvml_devices);
278 } else {
279 devices.extend(Self::detect_cuda_gpus(timeout, cuda_version.as_deref()));
280 }
281
282 devices.extend(Self::detect_rocm_gpus(timeout, &os_release));
283 devices.extend(Self::detect_metal(SYSTEM_PROFILER_TIMEOUT, &os_release));
284 devices.extend(Self::detect_directml(timeout, &os_release));
285 devices.extend(Self::detect_oneapi(timeout));
286 devices.extend(Self::detect_webgpu());
287
288 let (has_cuda, has_metal, has_rocm, has_directml) = Self::provider_flags(&devices);
289
290 Ok(Self {
291 cpu_cores,
292 total_memory,
293 os_type,
294 os_release,
295 has_cuda,
296 has_metal,
297 has_rocm,
298 has_directml,
299 devices,
300 })
301 }
302
303 fn fallback() -> Self {
304 Self {
305 cpu_cores: 1,
306 total_memory: 0,
307 os_type: "unknown".to_string(),
308 os_release: "unknown".to_string(),
309 has_cuda: false,
310 has_metal: false,
311 has_rocm: false,
312 has_directml: false,
313 devices: vec![Device {
314 id: 0,
315 name: "CPU".to_string(),
316 backend: DeviceBackend::Cpu,
317 memory_mb: 0,
318 compute_units: 1,
319 pci_bus_id: None,
320 partition_id: None,
321 driver_version: None,
322 cuda_version: None,
323 compute_capability: None,
324 utilization_gpu_pct: None,
325 temperature_c: None,
326 supports_fp16: true,
327 supports_int8: true,
328 }],
329 }
330 }
331
332 fn provider_flags(devices: &[Device]) -> (bool, bool, bool, bool) {
333 let has_cuda = devices
334 .iter()
335 .any(|d| matches!(d.backend, DeviceBackend::Cuda));
336 let has_metal = devices
337 .iter()
338 .any(|d| matches!(d.backend, DeviceBackend::Metal));
339 let has_rocm = devices
340 .iter()
341 .any(|d| matches!(d.backend, DeviceBackend::Rocm));
342 let has_directml = devices
343 .iter()
344 .any(|d| matches!(d.backend, DeviceBackend::DirectML));
345 (has_cuda, has_metal, has_rocm, has_directml)
346 }
347
348 fn run_command_with_timeout(
349 program: &str,
350 args: &[&str],
351 timeout: Duration,
352 ) -> Option<Vec<u8>> {
353 let mut child = Command::new(program)
354 .args(args)
355 .stdout(Stdio::piped())
356 .stderr(Stdio::piped())
357 .spawn()
358 .ok()?;
359
360 let start = Instant::now();
361 loop {
362 if start.elapsed() >= timeout {
363 let _ = child.kill();
364 let _ = child.wait();
365 return None;
366 }
367
368 match child.try_wait().ok()? {
369 Some(_) => break,
370 None => std::thread::sleep(COMMAND_POLL_INTERVAL),
371 }
372 }
373
374 let out = child.wait_with_output().ok()?;
375 if !out.status.success() {
376 return None;
377 }
378 Some(out.stdout)
379 }
380
381 fn parse_cuda_version_from_smi_summary(stdout: &str) -> Option<String> {
382 let needle = "CUDA Version:";
385 let idx = stdout.find(needle)?;
386 let rest = &stdout[idx + needle.len()..];
387 let version = rest
388 .trim_start()
389 .split(|c: char| c.is_whitespace() || c == '|')
390 .next()?
391 .trim();
392 if version.is_empty() {
393 None
394 } else {
395 Some(version.to_string())
396 }
397 }
398
399 fn detect_cuda_version(timeout: Duration) -> Option<String> {
400 let stdout = Self::run_command_with_timeout("nvidia-smi", &[], timeout)?;
401 let text = String::from_utf8_lossy(&stdout);
402 Self::parse_cuda_version_from_smi_summary(&text)
403 }
404
405 fn parse_compute_capability(value: &str) -> Option<(u32, u32)> {
406 let trimmed = value.trim();
407 if trimmed.is_empty() || trimmed.eq_ignore_ascii_case("n/a") {
408 return None;
409 }
410 let (major, minor) = trimmed.split_once('.')?;
411 Some((major.parse().ok()?, minor.parse().ok()?))
412 }
413
414 fn detect_cuda_gpus(timeout: Duration, cuda_version: Option<&str>) -> Vec<Device> {
415 let mut devices = Vec::new();
416
417 let query = "index,name,memory.total,utilization.gpu,temperature.gpu,pci.bus_id,driver_version,compute_cap,uuid";
418 let args = ["--query-gpu", query, "--format=csv,noheader,nounits"];
419
420 let stdout = match Self::run_command_with_timeout("nvidia-smi", &args, timeout) {
421 Some(s) => s,
422 None => return devices,
423 };
424
425 let text = String::from_utf8_lossy(&stdout);
426 for line in text.lines() {
427 let parts: Vec<&str> = line.split(',').map(|s| s.trim()).collect();
428 if parts.len() < 7 {
429 continue;
430 }
431
432 let id = parts[0].parse::<usize>().unwrap_or(0);
433 let name = parts.get(1).copied().unwrap_or("NVIDIA GPU").to_string();
434 let memory_mb = parts
435 .get(2)
436 .and_then(|v| v.parse::<u64>().ok())
437 .unwrap_or(0);
438 let utilization_gpu_pct = parts.get(3).and_then(|v| v.parse::<u32>().ok());
439 let temperature_c = parts.get(4).and_then(|v| v.parse::<u32>().ok());
440
441 let pci_bus_id = parts
442 .get(5)
443 .map(|v| v.trim())
444 .filter(|v| !v.is_empty() && !v.eq_ignore_ascii_case("n/a"))
445 .map(|v| v.to_string());
446
447 let driver_version = parts
448 .get(6)
449 .map(|v| v.trim())
450 .filter(|v| !v.is_empty() && !v.eq_ignore_ascii_case("n/a"))
451 .map(|v| v.to_string());
452
453 let compute_capability = parts
454 .get(7)
455 .map(|v| v.trim())
456 .filter(|v| !v.is_empty() && !v.eq_ignore_ascii_case("n/a"))
457 .map(|v| v.to_string());
458
459 let partition_id = parts
462 .get(8)
463 .map(|v| v.trim())
464 .filter(|v| !v.is_empty() && !v.eq_ignore_ascii_case("n/a"))
465 .map(|v| v.to_string());
466
467 let (supports_fp16, supports_int8) = match compute_capability
468 .as_deref()
469 .and_then(Self::parse_compute_capability)
470 {
471 Some((major, _minor)) => (major >= 5, major >= 6),
472 None => (true, true),
473 };
474
475 devices.push(Device {
476 id,
477 name,
478 backend: DeviceBackend::Cuda,
479 memory_mb,
480 compute_units: 0,
481 pci_bus_id,
482 partition_id,
483 driver_version,
484 cuda_version: cuda_version.map(|s| s.to_string()),
485 compute_capability,
486 utilization_gpu_pct,
487 temperature_c,
488 supports_fp16,
489 supports_int8,
490 });
491 }
492
493 devices
494 }
495
496 fn detect_rocm_gpus(timeout: Duration, os_release: &str) -> Vec<Device> {
497 let mut devices = Vec::new();
498
499 let stdout = match Self::run_command_with_timeout("rocm-smi", &["-i"], timeout) {
500 Some(s) => s,
501 None => return devices,
502 };
503
504 let text = String::from_utf8_lossy(&stdout);
505
506 let mut ids = Vec::new();
508 for line in text.lines() {
509 let line = line.trim();
510 if let Some(start) = line.find("GPU[") {
511 let rest = &line[start + 4..];
512 if let Some(end) = rest.find(']') {
513 if let Ok(id) = rest[..end].parse::<usize>() {
514 if !ids.contains(&id) {
515 ids.push(id);
516 }
517 }
518 }
519 }
520 }
521
522 if ids.is_empty() {
523 for line in text.lines() {
525 let line = line.trim_start();
526 if line.is_empty() {
527 continue;
528 }
529 let first = line.split_whitespace().next().unwrap_or("");
530 if let Ok(id) = first.parse::<usize>() {
531 if !ids.contains(&id) {
532 ids.push(id);
533 }
534 }
535 }
536 }
537
538 ids.sort_unstable();
539 for id in ids {
540 devices.push(Device {
541 id,
542 name: format!("AMD ROCm GPU {id}"),
543 backend: DeviceBackend::Rocm,
544 memory_mb: 0,
545 compute_units: 0,
546 pci_bus_id: None,
547 partition_id: None,
548 driver_version: Some(os_release.to_string()),
549 cuda_version: None,
550 compute_capability: None,
551 utilization_gpu_pct: None,
552 temperature_c: None,
553 supports_fp16: true,
554 supports_int8: true,
555 });
556 }
557
558 devices
559 }
560
561 fn parse_memory_mb(value: &str) -> Option<u64> {
562 let lowered = value.trim().to_ascii_lowercase();
563 if lowered.is_empty() {
564 return None;
565 }
566
567 let mut number = String::new();
568 for ch in lowered.chars() {
569 if ch.is_ascii_digit() || ch == '.' {
570 number.push(ch);
571 } else if !number.is_empty() {
572 break;
573 }
574 }
575
576 let num: f64 = number.parse().ok()?;
577 if lowered.contains("gb") {
578 Some((num * 1024.0) as u64)
579 } else if lowered.contains("mb") {
580 Some(num as u64)
581 } else if lowered.contains("kb") {
582 Some((num / 1024.0) as u64)
583 } else {
584 None
585 }
586 }
587
588 fn detect_metal(timeout: Duration, os_release: &str) -> Vec<Device> {
589 #[cfg(target_os = "macos")]
590 {
591 let mut devs = Vec::new();
592 let stdout = match Self::run_command_with_timeout(
593 "system_profiler",
594 &["SPDisplaysDataType", "-json"],
595 timeout,
596 ) {
597 Some(s) => s,
598 None => return devs,
599 };
600
601 let value: serde_json::Value = match serde_json::from_slice(&stdout) {
602 Ok(v) => v,
603 Err(_) => return devs,
604 };
605
606 let displays = match value.get("SPDisplaysDataType").and_then(|v| v.as_array()) {
607 Some(v) => v,
608 None => return devs,
609 };
610
611 for (idx, item) in displays.iter().enumerate() {
612 let name = item
613 .get("spdisplays_chipset_model")
614 .and_then(|v| v.as_str())
615 .or_else(|| item.get("sppci_model").and_then(|v| v.as_str()))
616 .or_else(|| item.get("_name").and_then(|v| v.as_str()))
617 .unwrap_or("Apple Metal GPU")
618 .to_string();
619
620 let vram_text = item
621 .get("spdisplays_vram")
622 .and_then(|v| v.as_str())
623 .or_else(|| item.get("spdisplays_vram_shared").and_then(|v| v.as_str()))
624 .unwrap_or("");
625
626 let memory_mb = Self::parse_memory_mb(vram_text).unwrap_or(0);
627
628 devs.push(Device {
629 id: idx,
630 name,
631 backend: DeviceBackend::Metal,
632 memory_mb,
633 compute_units: 0,
634 pci_bus_id: None,
635 partition_id: None,
636 driver_version: Some(os_release.to_string()),
637 cuda_version: None,
638 compute_capability: None,
639 utilization_gpu_pct: None,
640 temperature_c: None,
641 supports_fp16: true,
642 supports_int8: true,
643 });
644 }
645
646 devs
647 }
648
649 #[cfg(not(target_os = "macos"))]
650 {
651 let _ = (timeout, os_release);
652 Vec::new()
653 }
654 }
655
656 fn detect_directml(_timeout: Duration, os_release: &str) -> Vec<Device> {
657 #[cfg(target_os = "windows")]
658 {
659 vec![Device {
662 id: 0,
663 name: "DirectML GPU".into(),
664 backend: DeviceBackend::DirectML,
665 memory_mb: 0,
666 compute_units: 0,
667 pci_bus_id: None,
668 partition_id: None,
669 driver_version: Some(os_release.to_string()),
670 cuda_version: None,
671 compute_capability: None,
672 utilization_gpu_pct: None,
673 temperature_c: None,
674 supports_fp16: true,
675 supports_int8: true,
676 }]
677 }
678
679 #[cfg(not(target_os = "windows"))]
680 {
681 let _ = os_release;
682 Vec::new()
683 }
684 }
685
686 fn detect_oneapi(_timeout: Duration) -> Vec<Device> {
687 Vec::new()
691 }
692
693 fn detect_webgpu() -> Vec<Device> {
694 Vec::new()
696 }
697
698 #[cfg(feature = "nvml")]
699 fn detect_cuda_gpus_nvml(cuda_version: Option<&str>) -> Option<Vec<Device>> {
700 use nvml_wrapper::Nvml;
701
702 let nvml = Nvml::init().ok()?;
703 let driver_version = nvml.sys_driver_version().ok();
704 let count = nvml.device_count().ok()?;
705
706 let mut devices = Vec::with_capacity(count as usize);
707 for index in 0..count {
708 let dev = nvml.device_by_index(index).ok()?;
709 let name = dev.name().ok().unwrap_or_else(|| "NVIDIA GPU".to_string());
710 let memory_mb = dev
711 .memory_info()
712 .ok()
713 .map(|m| m.total / (1024 * 1024))
714 .unwrap_or(0);
715 let pci_bus_id = dev
716 .pci_info()
717 .ok()
718 .map(|p| p.bus_id)
719 .filter(|s| !s.trim().is_empty());
720 let utilization_gpu_pct = dev.utilization_rates().ok().map(|u| u.gpu);
721 let temperature_c = dev
722 .temperature(nvml_wrapper::enum_wrappers::device::TemperatureSensor::Gpu)
723 .ok();
724 let cc = dev
725 .cuda_compute_capability()
726 .ok()
727 .map(|(maj, min)| format!("{maj}.{min}"));
728
729 let partition_id = dev.uuid().ok().filter(|s| !s.trim().is_empty());
733
734 let (supports_fp16, supports_int8) =
735 match cc.as_deref().and_then(Self::parse_compute_capability) {
736 Some((major, _minor)) => (major >= 5, major >= 6),
737 None => (true, true),
738 };
739
740 devices.push(Device {
741 id: index as usize,
742 name,
743 backend: DeviceBackend::Cuda,
744 memory_mb,
745 compute_units: 0,
746 pci_bus_id,
747 partition_id,
748 driver_version: driver_version.clone(),
749 cuda_version: cuda_version.map(|s| s.to_string()),
750 compute_capability: cc,
751 utilization_gpu_pct,
752 temperature_c,
753 supports_fp16,
754 supports_int8,
755 });
756 }
757
758 Some(devices)
759 }
760
761 #[cfg(not(feature = "nvml"))]
762 fn detect_cuda_gpus_nvml(_cuda_version: Option<&str>) -> Option<Vec<Device>> {
763 None
764 }
765
766 pub fn best_gpu(&self) -> Option<&Device> {
771 self.devices
772 .iter()
773 .filter(|d| !matches!(d.backend, DeviceBackend::Cpu))
774 .max_by(|a, b| {
775 let by_mem = a.memory_mb.cmp(&b.memory_mb);
776 if by_mem != std::cmp::Ordering::Equal {
777 return by_mem;
778 }
779
780 let a_cc = a
781 .compute_capability
782 .as_deref()
783 .and_then(Self::parse_compute_capability)
784 .unwrap_or((0, 0));
785 let b_cc = b
786 .compute_capability
787 .as_deref()
788 .and_then(Self::parse_compute_capability)
789 .unwrap_or((0, 0));
790 a_cc.cmp(&b_cc)
791 })
792 }
793
794 pub fn best_gpu_with_preference(&self, preference: &GpuPreference) -> Option<&Device> {
795 self.devices
796 .iter()
797 .find(|d| !matches!(d.backend, DeviceBackend::Cpu) && preference.matches(d))
798 }
799
800 pub fn cuda_devices(&self) -> Vec<&Device> {
801 self.devices
802 .iter()
803 .filter(|d| matches!(d.backend, DeviceBackend::Cuda))
804 .collect()
805 }
806
807 pub fn get_best_provider(&self) -> String {
809 if self.has_cuda {
810 "cuda".to_string()
811 } else if self.has_metal {
812 "metal".to_string()
813 } else if self.has_rocm {
814 "rocm".to_string()
815 } else if self
816 .devices
817 .iter()
818 .any(|d| matches!(d.backend, DeviceBackend::OneApi))
819 {
820 "oneapi".to_string()
821 } else if self.has_directml {
822 "directml".to_string()
823 } else {
824 "cpu".to_string()
825 }
826 }
827
828 pub fn has_provider(&self, provider: &str) -> bool {
830 let key = provider.trim().to_ascii_lowercase();
831 match key.as_str() {
832 "cuda" => self.has_cuda,
833 "metal" | "coreml" => self.has_metal,
834 "rocm" => self.has_rocm,
835 "directml" => self.has_directml,
836 "oneapi" => self
837 .devices
838 .iter()
839 .any(|d| matches!(d.backend, DeviceBackend::OneApi)),
840 "webgpu" => self
841 .devices
842 .iter()
843 .any(|d| matches!(d.backend, DeviceBackend::WebGpu)),
844 "cpu" => true,
845 other => self
846 .devices
847 .iter()
848 .any(|d| d.backend.to_string().eq_ignore_ascii_case(other)),
849 }
850 }
851}