1use std::str::FromStr;
14use std::sync::OnceLock;
15
16use crate::error::PiperError;
17
/// Inference backend a model can be run on.
///
/// This is a plain fieldless tag, so it is `Copy`: pass it by value freely.
/// Its `Display` form (`cpu`, `cuda`, `coreml`, `directml`, `tensorrt`) is the
/// same spelling accepted when parsing a `DeviceSelection` from a string.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DeviceKind {
    /// CPU execution; always reported as available.
    Cpu,
    /// CUDA GPU execution (requires the `cuda` feature).
    Cuda,
    /// CoreML execution (requires the `coreml` feature on macOS).
    CoreML,
    /// DirectML execution (requires the `directml` feature on Windows).
    DirectML,
    /// TensorRT execution (requires the `tensorrt` feature).
    TensorRT,
}
27
28impl std::fmt::Display for DeviceKind {
29 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30 match self {
31 Self::Cpu => write!(f, "cpu"),
32 Self::Cuda => write!(f, "cuda"),
33 Self::CoreML => write!(f, "coreml"),
34 Self::DirectML => write!(f, "directml"),
35 Self::TensorRT => write!(f, "tensorrt"),
36 }
37 }
38}
39
40#[derive(Debug, Clone)]
42pub struct DeviceInfo {
43 pub kind: DeviceKind,
44 pub device_id: i32,
45 pub name: String,
46 pub available: bool,
47 pub memory_bytes: Option<u64>,
48}
49
50impl std::fmt::Display for DeviceInfo {
51 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52 let id_str = if self.kind == DeviceKind::Cpu {
54 format!("{}", self.kind)
55 } else {
56 format!("{}:{}", self.kind, self.device_id)
57 };
58
59 let mem_str = match self.memory_bytes {
60 Some(bytes) => {
61 let gb = bytes as f64 / (1024.0 * 1024.0 * 1024.0);
62 format!(", {gb:.0}GB")
63 }
64 None => String::new(),
65 };
66
67 let status = if self.available {
68 "available"
69 } else {
70 "unavailable"
71 };
72
73 write!(f, "{id_str} ({}{mem_str}) [{status}]", self.name)
74 }
75}
76
77#[derive(Debug, Clone)]
79pub struct DeviceSelection {
80 pub kind: DeviceKind,
81 pub device_id: i32,
82}
83
84impl DeviceSelection {
85 pub fn cpu() -> Self {
87 Self {
88 kind: DeviceKind::Cpu,
89 device_id: 0,
90 }
91 }
92
93 pub fn cuda(device_id: i32) -> Self {
95 Self {
96 kind: DeviceKind::Cuda,
97 device_id,
98 }
99 }
100
101 pub fn coreml() -> Self {
103 Self {
104 kind: DeviceKind::CoreML,
105 device_id: 0,
106 }
107 }
108
109 pub fn directml(device_id: i32) -> Self {
111 Self {
112 kind: DeviceKind::DirectML,
113 device_id,
114 }
115 }
116
117 pub fn auto() -> Self {
128 #[cfg(target_os = "macos")]
129 {
130 if is_device_available(&DeviceKind::CoreML) {
131 return Self::coreml();
132 }
133 }
134
135 #[cfg(target_os = "linux")]
136 {
137 if is_device_available(&DeviceKind::Cuda) {
138 return Self::cuda(0);
139 }
140 }
141
142 #[cfg(target_os = "windows")]
143 {
144 if is_device_available(&DeviceKind::DirectML) {
145 return Self::directml(0);
146 }
147 }
148
149 Self::cpu()
150 }
151}
152
153impl FromStr for DeviceSelection {
158 type Err = PiperError;
159
160 fn from_str(s: &str) -> Result<Self, Self::Err> {
161 let s = s.trim().to_ascii_lowercase();
162
163 if s.is_empty() {
164 return Err(PiperError::InvalidConfig {
165 reason: "empty device string".to_string(),
166 });
167 }
168
169 if s == "auto" {
170 return Ok(Self::auto());
171 }
172
173 let (kind_str, device_id) = if let Some((kind_part, id_part)) = s.split_once(':') {
175 let id: i32 = id_part.parse().map_err(|_| PiperError::InvalidConfig {
176 reason: format!("invalid device id: '{id_part}'"),
177 })?;
178 if id < 0 {
179 return Err(PiperError::InvalidConfig {
180 reason: format!("negative device ID not allowed: {id}"),
181 });
182 }
183 (kind_part, id)
184 } else {
185 (s.as_str(), 0)
186 };
187
188 match kind_str {
189 "cpu" => {
190 if device_id != 0 {
191 return Err(PiperError::InvalidConfig {
192 reason: "cpu does not accept a device ID".to_string(),
193 });
194 }
195 Ok(Self {
196 kind: DeviceKind::Cpu,
197 device_id: 0,
198 })
199 }
200 "cuda" => Ok(Self {
201 kind: DeviceKind::Cuda,
202 device_id,
203 }),
204 "coreml" => {
205 if device_id != 0 {
206 return Err(PiperError::InvalidConfig {
207 reason: "coreml does not accept a device ID".to_string(),
208 });
209 }
210 Ok(Self {
211 kind: DeviceKind::CoreML,
212 device_id: 0,
213 })
214 }
215 "directml" => Ok(Self {
216 kind: DeviceKind::DirectML,
217 device_id,
218 }),
219 "tensorrt" => Ok(Self {
220 kind: DeviceKind::TensorRT,
221 device_id,
222 }),
223 _ => Err(PiperError::InvalidConfig {
224 reason: format!("unknown device kind: '{kind_str}'"),
225 }),
226 }
227 }
228}
229
230impl std::fmt::Display for DeviceSelection {
231 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
232 if self.kind == DeviceKind::Cpu {
233 write!(f, "cpu")
234 } else {
235 write!(f, "{}:{}", self.kind, self.device_id)
236 }
237 }
238}
239
240pub fn enumerate_devices() -> &'static [DeviceInfo] {
247 static DEVICES: OnceLock<Vec<DeviceInfo>> = OnceLock::new();
248 DEVICES.get_or_init(|| {
249 let mut devices = Vec::new();
250
251 devices.push(DeviceInfo {
253 kind: DeviceKind::Cpu,
254 device_id: 0,
255 name: "CPU".to_string(),
256 available: true,
257 memory_bytes: None,
258 });
259
260 #[cfg(feature = "cuda")]
262 {
263 devices.push(DeviceInfo {
267 kind: DeviceKind::Cuda,
268 device_id: 0,
269 name: "CUDA Device 0".to_string(),
270 available: true,
271 memory_bytes: None,
272 });
273 }
274
275 #[cfg(all(feature = "coreml", target_os = "macos"))]
277 {
278 devices.push(DeviceInfo {
279 kind: DeviceKind::CoreML,
280 device_id: 0,
281 name: "Apple Neural Engine / GPU".to_string(),
282 available: true,
283 memory_bytes: None,
284 });
285 }
286
287 #[cfg(all(feature = "directml", target_os = "windows"))]
289 {
290 devices.push(DeviceInfo {
291 kind: DeviceKind::DirectML,
292 device_id: 0,
293 name: "DirectML Device 0".to_string(),
294 available: true,
295 memory_bytes: None,
296 });
297 }
298
299 #[cfg(feature = "tensorrt")]
301 {
302 devices.push(DeviceInfo {
303 kind: DeviceKind::TensorRT,
304 device_id: 0,
305 name: "TensorRT Device 0".to_string(),
306 available: true,
307 memory_bytes: None,
308 });
309 }
310
311 devices
312 })
313}
314
315pub fn is_device_available(kind: &DeviceKind) -> bool {
325 struct Availability {
327 cuda: bool,
328 coreml: bool,
329 directml: bool,
330 tensorrt: bool,
331 }
332
333 static AVAIL: OnceLock<Availability> = OnceLock::new();
334 let avail = AVAIL.get_or_init(|| Availability {
335 cuda: {
336 #[cfg(feature = "cuda")]
337 {
338 true
339 }
340 #[cfg(not(feature = "cuda"))]
341 {
342 false
343 }
344 },
345 coreml: {
346 #[cfg(all(feature = "coreml", target_os = "macos"))]
347 {
348 true
349 }
350 #[cfg(not(all(feature = "coreml", target_os = "macos")))]
351 {
352 false
353 }
354 },
355 directml: {
356 #[cfg(all(feature = "directml", target_os = "windows"))]
357 {
358 true
359 }
360 #[cfg(not(all(feature = "directml", target_os = "windows")))]
361 {
362 false
363 }
364 },
365 tensorrt: {
366 #[cfg(feature = "tensorrt")]
367 {
368 true
369 }
370 #[cfg(not(feature = "tensorrt"))]
371 {
372 false
373 }
374 },
375 });
376
377 match kind {
378 DeviceKind::Cpu => true,
379 DeviceKind::Cuda => avail.cuda,
380 DeviceKind::CoreML => avail.coreml,
381 DeviceKind::DirectML => avail.directml,
382 DeviceKind::TensorRT => avail.tensorrt,
383 }
384}
385
386pub fn recommended_device() -> DeviceSelection {
391 DeviceSelection::auto()
392}
393
394impl From<DeviceSelection> for crate::gpu::DeviceType {
401 fn from(sel: DeviceSelection) -> Self {
402 match sel.kind {
403 DeviceKind::Cpu => crate::gpu::DeviceType::Cpu,
404 DeviceKind::Cuda => crate::gpu::DeviceType::Cuda {
405 device_id: sel.device_id,
406 },
407 DeviceKind::CoreML => crate::gpu::DeviceType::CoreML,
408 DeviceKind::DirectML => crate::gpu::DeviceType::DirectML {
409 device_id: sel.device_id,
410 },
411 DeviceKind::TensorRT => crate::gpu::DeviceType::TensorRT {
412 device_id: sel.device_id,
413 },
414 }
415 }
416}
417
#[cfg(test)]
mod tests {
    use super::*;

    // ---- DeviceSelection -> gpu::DeviceType conversion ----

    #[test]
    fn test_from_device_selection_cpu() {
        let sel = DeviceSelection::cpu();
        let dt: crate::gpu::DeviceType = sel.into();
        assert_eq!(dt, crate::gpu::DeviceType::Cpu);
    }

    #[test]
    fn test_from_device_selection_cuda() {
        let sel = DeviceSelection::cuda(2);
        let dt: crate::gpu::DeviceType = sel.into();
        assert_eq!(dt, crate::gpu::DeviceType::Cuda { device_id: 2 });
    }

    #[test]
    fn test_from_device_selection_coreml() {
        let sel = DeviceSelection::coreml();
        let dt: crate::gpu::DeviceType = sel.into();
        assert_eq!(dt, crate::gpu::DeviceType::CoreML);
    }

    #[test]
    fn test_from_device_selection_directml() {
        let sel = DeviceSelection::directml(1);
        let dt: crate::gpu::DeviceType = sel.into();
        assert_eq!(dt, crate::gpu::DeviceType::DirectML { device_id: 1 });
    }

    #[test]
    fn test_from_device_selection_tensorrt() {
        let sel = DeviceSelection {
            kind: DeviceKind::TensorRT,
            device_id: 0,
        };
        let dt: crate::gpu::DeviceType = sel.into();
        assert_eq!(dt, crate::gpu::DeviceType::TensorRT { device_id: 0 });
    }

    // ---- FromStr parsing ----

    #[test]
    fn test_from_str_cpu() {
        let sel = DeviceSelection::from_str("cpu").unwrap();
        assert_eq!(sel.kind, DeviceKind::Cpu);
        assert_eq!(sel.device_id, 0);
    }

    #[test]
    fn test_from_str_cuda_default() {
        let sel = DeviceSelection::from_str("cuda").unwrap();
        assert_eq!(sel.kind, DeviceKind::Cuda);
        assert_eq!(sel.device_id, 0);
    }

    #[test]
    fn test_from_str_cuda_with_id() {
        let sel = DeviceSelection::from_str("cuda:1").unwrap();
        assert_eq!(sel.kind, DeviceKind::Cuda);
        assert_eq!(sel.device_id, 1);
    }

    #[test]
    fn test_from_str_cuda_zero() {
        let sel = DeviceSelection::from_str("cuda:0").unwrap();
        assert_eq!(sel.kind, DeviceKind::Cuda);
        assert_eq!(sel.device_id, 0);
    }

    #[test]
    fn test_from_str_coreml() {
        let sel = DeviceSelection::from_str("coreml").unwrap();
        assert_eq!(sel.kind, DeviceKind::CoreML);
        assert_eq!(sel.device_id, 0);
    }

    #[test]
    fn test_from_str_directml() {
        let sel = DeviceSelection::from_str("directml").unwrap();
        assert_eq!(sel.kind, DeviceKind::DirectML);
        assert_eq!(sel.device_id, 0);
    }

    #[test]
    fn test_from_str_directml_with_id() {
        let sel = DeviceSelection::from_str("directml:2").unwrap();
        assert_eq!(sel.kind, DeviceKind::DirectML);
        assert_eq!(sel.device_id, 2);
    }

    #[test]
    fn test_from_str_tensorrt() {
        let sel = DeviceSelection::from_str("tensorrt").unwrap();
        assert_eq!(sel.kind, DeviceKind::TensorRT);
        assert_eq!(sel.device_id, 0);
    }

    #[test]
    fn test_from_str_tensorrt_with_id() {
        let sel = DeviceSelection::from_str("tensorrt:1").unwrap();
        assert_eq!(sel.kind, DeviceKind::TensorRT);
        assert_eq!(sel.device_id, 1);
    }

    #[test]
    fn test_from_str_auto() {
        let sel = DeviceSelection::from_str("auto").unwrap();
        // Which device is picked depends on the target OS and enabled features.
        assert!(
            sel.kind == DeviceKind::Cpu
                || sel.kind == DeviceKind::Cuda
                || sel.kind == DeviceKind::CoreML
                || sel.kind == DeviceKind::DirectML
        );
    }

    #[test]
    fn test_from_str_case_insensitive() {
        let sel = DeviceSelection::from_str("CUDA").unwrap();
        assert_eq!(sel.kind, DeviceKind::Cuda);
        assert_eq!(sel.device_id, 0);

        let sel2 = DeviceSelection::from_str("Cuda:1").unwrap();
        assert_eq!(sel2.kind, DeviceKind::Cuda);
        assert_eq!(sel2.device_id, 1);

        let sel3 = DeviceSelection::from_str("CPU").unwrap();
        assert_eq!(sel3.kind, DeviceKind::Cpu);

        let sel4 = DeviceSelection::from_str("CoreML").unwrap();
        assert_eq!(sel4.kind, DeviceKind::CoreML);
    }

    #[test]
    fn test_from_str_trims_whitespace() {
        let sel = DeviceSelection::from_str("  cuda:2  ").unwrap();
        assert_eq!(sel.kind, DeviceKind::Cuda);
        assert_eq!(sel.device_id, 2);

        // Whitespace-only input trims down to the empty string.
        assert!(DeviceSelection::from_str("   ").is_err());
    }

    #[test]
    fn test_from_str_invalid() {
        let err = DeviceSelection::from_str("invalid");
        assert!(err.is_err());
    }

    #[test]
    fn test_from_str_gpu_unknown() {
        let err = DeviceSelection::from_str("gpu");
        assert!(err.is_err());
    }

    #[test]
    fn test_from_str_empty() {
        let err = DeviceSelection::from_str("");
        assert!(err.is_err());
    }

    #[test]
    fn test_from_str_bad_device_id() {
        let err = DeviceSelection::from_str("cuda:abc");
        assert!(err.is_err());
    }

    #[test]
    fn test_from_str_missing_device_id() {
        // A trailing colon with no index is not a valid ID.
        let err = DeviceSelection::from_str("cuda:");
        assert!(err.is_err());
    }

    // ---- Constructors ----

    #[test]
    fn test_constructor_cpu() {
        let sel = DeviceSelection::cpu();
        assert_eq!(sel.kind, DeviceKind::Cpu);
        assert_eq!(sel.device_id, 0);
    }

    #[test]
    fn test_constructor_cuda() {
        let sel = DeviceSelection::cuda(3);
        assert_eq!(sel.kind, DeviceKind::Cuda);
        assert_eq!(sel.device_id, 3);
    }

    #[test]
    fn test_constructor_coreml() {
        let sel = DeviceSelection::coreml();
        assert_eq!(sel.kind, DeviceKind::CoreML);
        assert_eq!(sel.device_id, 0);
    }

    #[test]
    fn test_constructor_directml() {
        let sel = DeviceSelection::directml(1);
        assert_eq!(sel.kind, DeviceKind::DirectML);
        assert_eq!(sel.device_id, 1);
    }

    // ---- Display ----

    #[test]
    fn test_device_kind_display() {
        assert_eq!(DeviceKind::Cpu.to_string(), "cpu");
        assert_eq!(DeviceKind::Cuda.to_string(), "cuda");
        assert_eq!(DeviceKind::CoreML.to_string(), "coreml");
        assert_eq!(DeviceKind::DirectML.to_string(), "directml");
        assert_eq!(DeviceKind::TensorRT.to_string(), "tensorrt");
    }

    #[test]
    fn test_device_info_display_cpu() {
        let info = DeviceInfo {
            kind: DeviceKind::Cpu,
            device_id: 0,
            name: "CPU".to_string(),
            available: true,
            memory_bytes: None,
        };
        let s = info.to_string();
        assert_eq!(s, "cpu (CPU) [available]");
    }

    #[test]
    fn test_device_info_display_cuda_with_memory() {
        let info = DeviceInfo {
            kind: DeviceKind::Cuda,
            device_id: 0,
            name: "NVIDIA GeForce RTX 3090".to_string(),
            available: true,
            // 24 GiB
            memory_bytes: Some(24 * 1024 * 1024 * 1024),
        };
        let s = info.to_string();
        assert_eq!(s, "cuda:0 (NVIDIA GeForce RTX 3090, 24GB) [available]");
    }

    #[test]
    fn test_device_info_display_unavailable() {
        let info = DeviceInfo {
            kind: DeviceKind::Cuda,
            device_id: 1,
            name: "CUDA Device 1".to_string(),
            available: false,
            memory_bytes: None,
        };
        let s = info.to_string();
        assert_eq!(s, "cuda:1 (CUDA Device 1) [unavailable]");
    }

    // ---- Enumeration / availability ----

    #[test]
    fn test_enumerate_devices_always_includes_cpu() {
        let devices = enumerate_devices();
        assert!(!devices.is_empty());
        assert!(devices.iter().any(|d| d.kind == DeviceKind::Cpu));
        let cpu = devices.iter().find(|d| d.kind == DeviceKind::Cpu).unwrap();
        assert!(cpu.available);
    }

    #[test]
    fn test_cpu_always_available() {
        assert!(is_device_available(&DeviceKind::Cpu));
    }

    #[test]
    fn test_enumerate_devices_consistent_with_availability() {
        for d in enumerate_devices() {
            assert!(
                is_device_available(&d.kind),
                "listed device {} is not reported as available",
                d
            );
        }
    }

    #[test]
    fn test_auto_returns_valid_device() {
        let sel = DeviceSelection::auto();
        assert!(
            sel.kind == DeviceKind::Cpu
                || sel.kind == DeviceKind::Cuda
                || sel.kind == DeviceKind::CoreML
                || sel.kind == DeviceKind::DirectML
        );
        assert!(sel.device_id >= 0);
    }

    #[test]
    fn test_recommended_device_returns_valid() {
        let sel = recommended_device();
        assert!(
            sel.kind == DeviceKind::Cpu
                || sel.kind == DeviceKind::Cuda
                || sel.kind == DeviceKind::CoreML
                || sel.kind == DeviceKind::DirectML
        );
        assert!(sel.device_id >= 0);
    }

    #[test]
    fn test_device_selection_display_cpu() {
        let sel = DeviceSelection::cpu();
        assert_eq!(sel.to_string(), "cpu");
    }

    #[test]
    fn test_device_selection_display_cuda() {
        let sel = DeviceSelection::cuda(1);
        assert_eq!(sel.to_string(), "cuda:1");
    }

    #[test]
    fn test_device_kind_eq_and_hash() {
        use std::collections::HashSet;
        let mut set = HashSet::new();
        set.insert(DeviceKind::Cpu);
        set.insert(DeviceKind::Cuda);
        // Re-inserting an equal kind must not grow the set.
        set.insert(DeviceKind::Cpu);
        assert_eq!(set.len(), 2);
        assert!(set.contains(&DeviceKind::Cpu));
        assert!(set.contains(&DeviceKind::Cuda));
        assert!(!set.contains(&DeviceKind::CoreML));
    }

    // ---- FromStr validation of device IDs ----

    #[test]
    fn test_device_selection_from_str_negative_id() {
        let result = DeviceSelection::from_str("cuda:-1");
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(
            err_msg.contains("negative device ID"),
            "error should mention negative device ID, got: {err_msg}"
        );
    }

    #[test]
    fn test_device_selection_from_str_cpu_with_id_rejected() {
        let result = DeviceSelection::from_str("cpu:1");
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(
            err_msg.contains("cpu does not accept a device ID"),
            "error should mention cpu device ID, got: {err_msg}"
        );
    }

    #[test]
    fn test_device_selection_from_str_cpu_zero_ok() {
        let sel = DeviceSelection::from_str("cpu:0").unwrap();
        assert_eq!(sel.kind, DeviceKind::Cpu);
        assert_eq!(sel.device_id, 0);
    }

    #[test]
    fn test_device_selection_from_str_coreml_with_id_rejected() {
        let result = DeviceSelection::from_str("coreml:1");
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(
            err_msg.contains("coreml does not accept a device ID"),
            "error should mention coreml device ID, got: {err_msg}"
        );
    }

    #[test]
    fn test_device_selection_from_str_coreml_zero_ok() {
        let sel = DeviceSelection::from_str("coreml:0").unwrap();
        assert_eq!(sel.kind, DeviceKind::CoreML);
        assert_eq!(sel.device_id, 0);
    }

    #[test]
    fn test_device_selection_display_roundtrip() {
        let cases = [
            DeviceSelection::cpu(),
            DeviceSelection::cuda(0),
            DeviceSelection::cuda(3),
            DeviceSelection::coreml(),
            DeviceSelection::directml(0),
            DeviceSelection::directml(2),
            DeviceSelection {
                kind: DeviceKind::TensorRT,
                device_id: 1,
            },
        ];
        for sel in &cases {
            let displayed = sel.to_string();
            let parsed = DeviceSelection::from_str(&displayed).unwrap();
            assert_eq!(
                parsed.kind, sel.kind,
                "roundtrip kind failed for '{displayed}'"
            );
            assert_eq!(
                parsed.device_id, sel.device_id,
                "roundtrip id failed for '{displayed}'"
            );
        }
    }

    #[test]
    fn test_enumerate_devices_no_duplicates() {
        let devices = enumerate_devices();
        let mut seen_kinds: Vec<DeviceKind> = Vec::new();
        for d in devices {
            assert!(
                !seen_kinds.contains(&d.kind),
                "duplicate device kind: {:?}",
                d.kind
            );
            seen_kinds.push(d.kind.clone());
        }
    }

    #[test]
    fn test_device_info_memory_display_large() {
        // 80 GiB
        let memory: u64 = 80 * 1024 * 1024 * 1024;
        let info = DeviceInfo {
            kind: DeviceKind::Cuda,
            device_id: 0,
            name: "NVIDIA A100".to_string(),
            available: true,
            memory_bytes: Some(memory),
        };
        let s = info.to_string();
        assert!(s.contains("80GB"), "expected '80GB' in: {s}");
        assert!(s.contains("[available]"));
        assert!(s.contains("cuda:0"));
    }
}