Skip to main content

piper_plus/
device.rs

1//! High-level compute device enumeration and selection.
2//!
//! Provides a user-facing interface for discovering and selecting compute
//! devices (CPU, CUDA, CoreML, DirectML, TensorRT) for ONNX Runtime inference.
5//!
6//! This module operates at the **application layer** -- it handles user input
7//! parsing, device discovery, and display formatting.  The actual ONNX Runtime
8//! `ExecutionProvider` configuration lives in [`crate::gpu`], which is the
9//! **low-level ort integration layer**.  Use [`From<DeviceSelection>`] to
10//! convert a high-level selection into a [`crate::gpu::DeviceType`] suitable
11//! for passing to [`crate::gpu::configure_session_builder`].
12
13use std::str::FromStr;
14use std::sync::OnceLock;
15
16use crate::error::PiperError;
17
/// Compute device type.
///
/// The `Display` form is the lowercase backend name: `cpu`, `cuda`, `coreml`,
/// `directml` or `tensorrt`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum DeviceKind {
    /// Host CPU execution.
    Cpu,
    /// NVIDIA CUDA execution.
    Cuda,
    /// Apple CoreML execution.
    CoreML,
    /// Microsoft DirectML execution.
    DirectML,
    /// NVIDIA TensorRT execution.
    TensorRT,
}

impl std::fmt::Display for DeviceKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::Cpu => "cpu",
            Self::Cuda => "cuda",
            Self::CoreML => "coreml",
            Self::DirectML => "directml",
            Self::TensorRT => "tensorrt",
        };
        f.write_str(name)
    }
}

/// Information about a compute device.
#[derive(Debug, Clone)]
pub struct DeviceInfo {
    /// Backend this device belongs to.
    pub kind: DeviceKind,
    /// Backend-specific device index (omitted from the display form for CPU).
    pub device_id: i32,
    /// Human-readable device name.
    pub name: String,
    /// Whether the device is reported as usable.
    pub available: bool,
    /// Device memory in bytes, if known.
    pub memory_bytes: Option<u64>,
}

/// Formats as `<kind>[:<id>] (<name>[, <memory>]) [available|unavailable]`,
/// e.g. `cuda:0 (NVIDIA GeForce RTX 3090, 24GB) [available]`.
///
/// Memory of at least 1 GiB is shown in whole binary gigabytes (`GB`);
/// smaller amounts are shown in whole binary megabytes (`MB`) so that small
/// devices are not rendered as `0GB`.
impl std::fmt::Display for DeviceInfo {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const MIB: u64 = 1024 * 1024;
        const GIB: u64 = 1024 * MIB;

        // The CPU is a single logical device, so its index is not shown.
        if self.kind == DeviceKind::Cpu {
            write!(f, "{} ({}", self.kind, self.name)?;
        } else {
            write!(f, "{}:{} ({}", self.kind, self.device_id, self.name)?;
        }

        match self.memory_bytes {
            Some(bytes) if bytes >= GIB => write!(f, ", {:.0}GB", bytes as f64 / GIB as f64)?,
            Some(bytes) => write!(f, ", {:.0}MB", bytes as f64 / MIB as f64)?,
            None => {}
        }

        let status = if self.available {
            "available"
        } else {
            "unavailable"
        };
        write!(f, ") [{status}]")
    }
}
76
77/// Device selection specification.
78#[derive(Debug, Clone)]
79pub struct DeviceSelection {
80    pub kind: DeviceKind,
81    pub device_id: i32,
82}
83
84impl DeviceSelection {
85    /// Select CPU device.
86    pub fn cpu() -> Self {
87        Self {
88            kind: DeviceKind::Cpu,
89            device_id: 0,
90        }
91    }
92
93    /// Select CUDA device by index.
94    pub fn cuda(device_id: i32) -> Self {
95        Self {
96            kind: DeviceKind::Cuda,
97            device_id,
98        }
99    }
100
101    /// Select CoreML device.
102    pub fn coreml() -> Self {
103        Self {
104            kind: DeviceKind::CoreML,
105            device_id: 0,
106        }
107    }
108
109    /// Select DirectML device by index.
110    pub fn directml(device_id: i32) -> Self {
111        Self {
112            kind: DeviceKind::DirectML,
113            device_id,
114        }
115    }
116
117    /// Auto-select the best available device.
118    ///
119    /// Priority by platform:
120    /// - macOS: CoreML > CPU
121    /// - Linux: CUDA > CPU
122    /// - Windows: DirectML > CPU
123    /// - Other: CPU
124    ///
125    /// Feature flags are checked at compile time; if the preferred accelerator
126    /// was not compiled in, falls back to CPU.
127    pub fn auto() -> Self {
128        #[cfg(target_os = "macos")]
129        {
130            if is_device_available(&DeviceKind::CoreML) {
131                return Self::coreml();
132            }
133        }
134
135        #[cfg(target_os = "linux")]
136        {
137            if is_device_available(&DeviceKind::Cuda) {
138                return Self::cuda(0);
139            }
140        }
141
142        #[cfg(target_os = "windows")]
143        {
144            if is_device_available(&DeviceKind::DirectML) {
145                return Self::directml(0);
146            }
147        }
148
149        Self::cpu()
150    }
151}
152
153/// Parse from string: `"cpu"`, `"cuda"`, `"cuda:0"`, `"cuda:1"`, `"coreml"`,
154/// `"directml"`, `"directml:0"`, `"tensorrt"`, `"tensorrt:0"`, `"auto"`.
155///
156/// Parsing is case-insensitive.
157impl FromStr for DeviceSelection {
158    type Err = PiperError;
159
160    fn from_str(s: &str) -> Result<Self, Self::Err> {
161        let s = s.trim().to_ascii_lowercase();
162
163        if s.is_empty() {
164            return Err(PiperError::InvalidConfig {
165                reason: "empty device string".to_string(),
166            });
167        }
168
169        if s == "auto" {
170            return Ok(Self::auto());
171        }
172
173        // Split on ':' to extract optional device_id
174        let (kind_str, device_id) = if let Some((kind_part, id_part)) = s.split_once(':') {
175            let id: i32 = id_part.parse().map_err(|_| PiperError::InvalidConfig {
176                reason: format!("invalid device id: '{id_part}'"),
177            })?;
178            if id < 0 {
179                return Err(PiperError::InvalidConfig {
180                    reason: format!("negative device ID not allowed: {id}"),
181                });
182            }
183            (kind_part, id)
184        } else {
185            (s.as_str(), 0)
186        };
187
188        match kind_str {
189            "cpu" => {
190                if device_id != 0 {
191                    return Err(PiperError::InvalidConfig {
192                        reason: "cpu does not accept a device ID".to_string(),
193                    });
194                }
195                Ok(Self {
196                    kind: DeviceKind::Cpu,
197                    device_id: 0,
198                })
199            }
200            "cuda" => Ok(Self {
201                kind: DeviceKind::Cuda,
202                device_id,
203            }),
204            "coreml" => {
205                if device_id != 0 {
206                    return Err(PiperError::InvalidConfig {
207                        reason: "coreml does not accept a device ID".to_string(),
208                    });
209                }
210                Ok(Self {
211                    kind: DeviceKind::CoreML,
212                    device_id: 0,
213                })
214            }
215            "directml" => Ok(Self {
216                kind: DeviceKind::DirectML,
217                device_id,
218            }),
219            "tensorrt" => Ok(Self {
220                kind: DeviceKind::TensorRT,
221                device_id,
222            }),
223            _ => Err(PiperError::InvalidConfig {
224                reason: format!("unknown device kind: '{kind_str}'"),
225            }),
226        }
227    }
228}
229
230impl std::fmt::Display for DeviceSelection {
231    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
232        if self.kind == DeviceKind::Cpu {
233            write!(f, "cpu")
234        } else {
235            write!(f, "{}:{}", self.kind, self.device_id)
236        }
237    }
238}
239
240/// Enumerate all available compute devices on this system.
241///
242/// CPU is always included. Accelerators are included only when the
243/// corresponding feature flag is compiled in.
244///
245/// Results are computed once and cached for the lifetime of the process.
246pub fn enumerate_devices() -> &'static [DeviceInfo] {
247    static DEVICES: OnceLock<Vec<DeviceInfo>> = OnceLock::new();
248    DEVICES.get_or_init(|| {
249        let mut devices = Vec::new();
250
251        // CPU is always available
252        devices.push(DeviceInfo {
253            kind: DeviceKind::Cpu,
254            device_id: 0,
255            name: "CPU".to_string(),
256            available: true,
257            memory_bytes: None,
258        });
259
260        // CUDA devices
261        #[cfg(feature = "cuda")]
262        {
263            // When the cuda feature is compiled, report at least device 0.
264            // Actual GPU enumeration would require the CUDA runtime; for now
265            // we advertise a single device whose availability is best-effort.
266            devices.push(DeviceInfo {
267                kind: DeviceKind::Cuda,
268                device_id: 0,
269                name: "CUDA Device 0".to_string(),
270                available: true,
271                memory_bytes: None,
272            });
273        }
274
275        // CoreML (macOS only)
276        #[cfg(all(feature = "coreml", target_os = "macos"))]
277        {
278            devices.push(DeviceInfo {
279                kind: DeviceKind::CoreML,
280                device_id: 0,
281                name: "Apple Neural Engine / GPU".to_string(),
282                available: true,
283                memory_bytes: None,
284            });
285        }
286
287        // DirectML (Windows only)
288        #[cfg(all(feature = "directml", target_os = "windows"))]
289        {
290            devices.push(DeviceInfo {
291                kind: DeviceKind::DirectML,
292                device_id: 0,
293                name: "DirectML Device 0".to_string(),
294                available: true,
295                memory_bytes: None,
296            });
297        }
298
299        // TensorRT (Linux typically)
300        #[cfg(feature = "tensorrt")]
301        {
302            devices.push(DeviceInfo {
303                kind: DeviceKind::TensorRT,
304                device_id: 0,
305                name: "TensorRT Device 0".to_string(),
306                available: true,
307                memory_bytes: None,
308            });
309        }
310
311        devices
312    })
313}
314
315/// Check if a specific device kind is available.
316///
317/// A device is considered available when both:
318/// 1. The corresponding feature flag was compiled in, and
319/// 2. The runtime can plausibly support it (e.g., correct OS).
320///
321/// CPU is always available.
322///
323/// Results are computed once and cached for the lifetime of the process.
324pub fn is_device_available(kind: &DeviceKind) -> bool {
325    /// Cached availability results for all device kinds.
326    struct Availability {
327        cuda: bool,
328        coreml: bool,
329        directml: bool,
330        tensorrt: bool,
331    }
332
333    static AVAIL: OnceLock<Availability> = OnceLock::new();
334    let avail = AVAIL.get_or_init(|| Availability {
335        cuda: {
336            #[cfg(feature = "cuda")]
337            {
338                true
339            }
340            #[cfg(not(feature = "cuda"))]
341            {
342                false
343            }
344        },
345        coreml: {
346            #[cfg(all(feature = "coreml", target_os = "macos"))]
347            {
348                true
349            }
350            #[cfg(not(all(feature = "coreml", target_os = "macos")))]
351            {
352                false
353            }
354        },
355        directml: {
356            #[cfg(all(feature = "directml", target_os = "windows"))]
357            {
358                true
359            }
360            #[cfg(not(all(feature = "directml", target_os = "windows")))]
361            {
362                false
363            }
364        },
365        tensorrt: {
366            #[cfg(feature = "tensorrt")]
367            {
368                true
369            }
370            #[cfg(not(feature = "tensorrt"))]
371            {
372                false
373            }
374        },
375    });
376
377    match kind {
378        DeviceKind::Cpu => true,
379        DeviceKind::Cuda => avail.cuda,
380        DeviceKind::CoreML => avail.coreml,
381        DeviceKind::DirectML => avail.directml,
382        DeviceKind::TensorRT => avail.tensorrt,
383    }
384}
385
386/// Get the recommended device for this platform.
387///
388/// This is equivalent to [`DeviceSelection::auto()`] but returned as a
389/// standalone function for convenience.
390pub fn recommended_device() -> DeviceSelection {
391    DeviceSelection::auto()
392}
393
394// ---------------------------------------------------------------------------
395// Bridge to gpu::DeviceType
396// ---------------------------------------------------------------------------
397
398/// Convert a high-level [`DeviceSelection`] into the low-level
399/// [`crate::gpu::DeviceType`] used by the ONNX Runtime session builder.
400impl From<DeviceSelection> for crate::gpu::DeviceType {
401    fn from(sel: DeviceSelection) -> Self {
402        match sel.kind {
403            DeviceKind::Cpu => crate::gpu::DeviceType::Cpu,
404            DeviceKind::Cuda => crate::gpu::DeviceType::Cuda {
405                device_id: sel.device_id,
406            },
407            DeviceKind::CoreML => crate::gpu::DeviceType::CoreML,
408            DeviceKind::DirectML => crate::gpu::DeviceType::DirectML {
409                device_id: sel.device_id,
410            },
411            DeviceKind::TensorRT => crate::gpu::DeviceType::TensorRT {
412                device_id: sel.device_id,
413            },
414        }
415    }
416}
417
#[cfg(test)]
mod tests {
    use super::*;

    // --- helpers ---

    /// Parse `input`, panicking with the offending input and error on failure.
    fn parse_ok(input: &str) -> DeviceSelection {
        DeviceSelection::from_str(input)
            .unwrap_or_else(|e| panic!("expected '{input}' to parse, got error: {e}"))
    }

    /// Parse `input` expecting failure and return the error's display text.
    fn parse_err(input: &str) -> String {
        match DeviceSelection::from_str(input) {
            Ok(sel) => panic!("expected '{input}' to be rejected, got {sel:?}"),
            Err(e) => e.to_string(),
        }
    }

    /// Assert that `input` parses to exactly (`kind`, `device_id`).
    fn assert_parses(input: &str, kind: DeviceKind, device_id: i32) {
        let sel = parse_ok(input);
        assert_eq!(sel.kind, kind, "kind mismatch for '{input}'");
        assert_eq!(sel.device_id, device_id, "device_id mismatch for '{input}'");
    }

    /// Assert that `input` is rejected with an error message containing `needle`.
    fn assert_rejected(input: &str, needle: &str) {
        let msg = parse_err(input);
        assert!(
            msg.contains(needle),
            "error for '{input}' should mention '{needle}', got: {msg}"
        );
    }

    /// Kinds that `DeviceSelection::auto()` may pick (it never picks TensorRT).
    fn is_auto_candidate(kind: &DeviceKind) -> bool {
        matches!(
            kind,
            DeviceKind::Cpu | DeviceKind::Cuda | DeviceKind::CoreML | DeviceKind::DirectML
        )
    }

    // --- DeviceSelection -> gpu::DeviceType conversion ---

    #[test]
    fn test_from_device_selection_cpu() {
        let dt: crate::gpu::DeviceType = DeviceSelection::cpu().into();
        assert_eq!(dt, crate::gpu::DeviceType::Cpu);
    }

    #[test]
    fn test_from_device_selection_cuda() {
        let dt: crate::gpu::DeviceType = DeviceSelection::cuda(2).into();
        assert_eq!(dt, crate::gpu::DeviceType::Cuda { device_id: 2 });
    }

    #[test]
    fn test_from_device_selection_coreml() {
        let dt: crate::gpu::DeviceType = DeviceSelection::coreml().into();
        assert_eq!(dt, crate::gpu::DeviceType::CoreML);
    }

    #[test]
    fn test_from_device_selection_directml() {
        let dt: crate::gpu::DeviceType = DeviceSelection::directml(1).into();
        assert_eq!(dt, crate::gpu::DeviceType::DirectML { device_id: 1 });
    }

    #[test]
    fn test_from_device_selection_tensorrt() {
        let sel = DeviceSelection {
            kind: DeviceKind::TensorRT,
            device_id: 0,
        };
        let dt: crate::gpu::DeviceType = sel.into();
        assert_eq!(dt, crate::gpu::DeviceType::TensorRT { device_id: 0 });
    }

    // --- DeviceSelection::from_str ---

    #[test]
    fn test_from_str_cpu() {
        assert_parses("cpu", DeviceKind::Cpu, 0);
    }

    #[test]
    fn test_from_str_cuda_default() {
        assert_parses("cuda", DeviceKind::Cuda, 0);
    }

    #[test]
    fn test_from_str_cuda_with_id() {
        assert_parses("cuda:1", DeviceKind::Cuda, 1);
    }

    #[test]
    fn test_from_str_cuda_zero() {
        assert_parses("cuda:0", DeviceKind::Cuda, 0);
    }

    #[test]
    fn test_from_str_coreml() {
        assert_parses("coreml", DeviceKind::CoreML, 0);
    }

    #[test]
    fn test_from_str_directml() {
        assert_parses("directml", DeviceKind::DirectML, 0);
    }

    #[test]
    fn test_from_str_directml_with_id() {
        assert_parses("directml:2", DeviceKind::DirectML, 2);
    }

    #[test]
    fn test_from_str_tensorrt() {
        assert_parses("tensorrt", DeviceKind::TensorRT, 0);
    }

    #[test]
    fn test_from_str_tensorrt_with_id() {
        assert_parses("tensorrt:1", DeviceKind::TensorRT, 1);
    }

    #[test]
    fn test_from_str_auto() {
        let parsed = parse_ok("auto");
        // auto always returns a valid device; on any platform CPU is the fallback
        assert!(is_auto_candidate(&parsed.kind));
        // "auto" must resolve exactly like DeviceSelection::auto().
        let direct = DeviceSelection::auto();
        assert_eq!(parsed.kind, direct.kind);
        assert_eq!(parsed.device_id, direct.device_id);
    }

    #[test]
    fn test_from_str_case_insensitive() {
        assert_parses("CUDA", DeviceKind::Cuda, 0);
        assert_parses("Cuda:1", DeviceKind::Cuda, 1);
        assert_parses("CPU", DeviceKind::Cpu, 0);
        assert_parses("CoreML", DeviceKind::CoreML, 0);
        assert_parses("TensorRT:2", DeviceKind::TensorRT, 2);
    }

    #[test]
    fn test_from_str_trims_surrounding_whitespace() {
        assert_parses("  cuda:2\n", DeviceKind::Cuda, 2);
        assert_parses("\tcpu ", DeviceKind::Cpu, 0);
    }

    // --- Error cases ---

    #[test]
    fn test_from_str_invalid() {
        assert_rejected("invalid", "unknown device kind");
    }

    #[test]
    fn test_from_str_gpu_unknown() {
        assert_rejected("gpu", "unknown device kind");
    }

    #[test]
    fn test_from_str_unknown_kind_with_id() {
        assert_rejected("gpu:0", "unknown device kind");
    }

    #[test]
    fn test_from_str_empty() {
        assert_rejected("", "empty device string");
        // Whitespace-only input is trimmed to empty.
        assert_rejected("   ", "empty device string");
    }

    #[test]
    fn test_from_str_bad_device_id() {
        assert_rejected("cuda:abc", "invalid device id");
        assert_rejected("cuda:", "invalid device id");
        assert_rejected("cuda:1:2", "invalid device id");
    }

    // --- Constructors ---

    #[test]
    fn test_constructor_cpu() {
        let sel = DeviceSelection::cpu();
        assert_eq!(sel.kind, DeviceKind::Cpu);
        assert_eq!(sel.device_id, 0);
    }

    #[test]
    fn test_constructor_cuda() {
        let sel = DeviceSelection::cuda(3);
        assert_eq!(sel.kind, DeviceKind::Cuda);
        assert_eq!(sel.device_id, 3);
    }

    #[test]
    fn test_constructor_coreml() {
        let sel = DeviceSelection::coreml();
        assert_eq!(sel.kind, DeviceKind::CoreML);
        assert_eq!(sel.device_id, 0);
    }

    #[test]
    fn test_constructor_directml() {
        let sel = DeviceSelection::directml(1);
        assert_eq!(sel.kind, DeviceKind::DirectML);
        assert_eq!(sel.device_id, 1);
    }

    // --- DeviceKind Display ---

    #[test]
    fn test_device_kind_display() {
        assert_eq!(DeviceKind::Cpu.to_string(), "cpu");
        assert_eq!(DeviceKind::Cuda.to_string(), "cuda");
        assert_eq!(DeviceKind::CoreML.to_string(), "coreml");
        assert_eq!(DeviceKind::DirectML.to_string(), "directml");
        assert_eq!(DeviceKind::TensorRT.to_string(), "tensorrt");
    }

    // --- DeviceInfo Display ---

    #[test]
    fn test_device_info_display_cpu() {
        let info = DeviceInfo {
            kind: DeviceKind::Cpu,
            device_id: 0,
            name: "CPU".to_string(),
            available: true,
            memory_bytes: None,
        };
        assert_eq!(info.to_string(), "cpu (CPU) [available]");
    }

    #[test]
    fn test_device_info_display_cuda_with_memory() {
        let info = DeviceInfo {
            kind: DeviceKind::Cuda,
            device_id: 0,
            name: "NVIDIA GeForce RTX 3090".to_string(),
            available: true,
            memory_bytes: Some(24 * 1024 * 1024 * 1024), // 24 GB
        };
        assert_eq!(
            info.to_string(),
            "cuda:0 (NVIDIA GeForce RTX 3090, 24GB) [available]"
        );
    }

    #[test]
    fn test_device_info_display_unavailable() {
        let info = DeviceInfo {
            kind: DeviceKind::Cuda,
            device_id: 1,
            name: "CUDA Device 1".to_string(),
            available: false,
            memory_bytes: None,
        };
        assert_eq!(info.to_string(), "cuda:1 (CUDA Device 1) [unavailable]");
    }

    // --- enumerate_devices ---

    #[test]
    fn test_enumerate_devices_always_includes_cpu() {
        let devices = enumerate_devices();
        assert!(!devices.is_empty());
        let cpu = devices
            .iter()
            .find(|d| d.kind == DeviceKind::Cpu)
            .expect("CPU must always be enumerated");
        assert!(cpu.available);
    }

    #[test]
    fn test_enumerate_devices_cpu_first_and_cached() {
        let first = enumerate_devices();
        assert_eq!(first[0].kind, DeviceKind::Cpu);
        // The list is built once; later calls return the very same slice.
        assert!(std::ptr::eq(first, enumerate_devices()));
    }

    #[test]
    fn test_enumerate_devices_matches_is_device_available() {
        let devices = enumerate_devices();
        let all_kinds = [
            DeviceKind::Cpu,
            DeviceKind::Cuda,
            DeviceKind::CoreML,
            DeviceKind::DirectML,
            DeviceKind::TensorRT,
        ];
        for kind in &all_kinds {
            let listed = devices.iter().any(|d| &d.kind == kind);
            assert_eq!(
                listed,
                is_device_available(kind),
                "enumerate_devices and is_device_available disagree on {kind}"
            );
        }
    }

    #[test]
    fn test_enumerate_devices_no_duplicates() {
        let devices = enumerate_devices();
        let mut seen_kinds: Vec<DeviceKind> = Vec::new();
        for d in devices {
            assert!(
                !seen_kinds.contains(&d.kind),
                "duplicate device kind: {:?}",
                d.kind
            );
            seen_kinds.push(d.kind.clone());
        }
    }

    // --- is_device_available ---

    #[test]
    fn test_cpu_always_available() {
        assert!(is_device_available(&DeviceKind::Cpu));
    }

    // --- auto / recommended ---

    #[test]
    fn test_auto_returns_valid_device() {
        let sel = DeviceSelection::auto();
        assert!(is_auto_candidate(&sel.kind));
        assert!(sel.device_id >= 0);
    }

    #[test]
    fn test_recommended_device_returns_valid() {
        let sel = recommended_device();
        assert!(is_auto_candidate(&sel.kind));
        assert!(sel.device_id >= 0);
        let auto = DeviceSelection::auto();
        assert_eq!(sel.kind, auto.kind);
        assert_eq!(sel.device_id, auto.device_id);
    }

    // --- DeviceSelection Display ---

    #[test]
    fn test_device_selection_display_cpu() {
        assert_eq!(DeviceSelection::cpu().to_string(), "cpu");
    }

    #[test]
    fn test_device_selection_display_cuda() {
        assert_eq!(DeviceSelection::cuda(1).to_string(), "cuda:1");
    }

    #[test]
    fn test_device_selection_display_tensorrt() {
        let sel = DeviceSelection {
            kind: DeviceKind::TensorRT,
            device_id: 2,
        };
        assert_eq!(sel.to_string(), "tensorrt:2");
    }

    // --- DeviceKind equality / Hash ---

    #[test]
    fn test_device_kind_eq_and_hash() {
        use std::collections::HashSet;
        let mut set = HashSet::new();
        set.insert(DeviceKind::Cpu);
        set.insert(DeviceKind::Cuda);
        set.insert(DeviceKind::Cpu); // duplicate
        assert_eq!(set.len(), 2);
        assert!(set.contains(&DeviceKind::Cpu));
        assert!(set.contains(&DeviceKind::Cuda));
        assert!(!set.contains(&DeviceKind::CoreML));
    }

    // -----------------------------------------------------------------------
    // Validation and roundtrip tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_device_selection_from_str_negative_id() {
        // Negative device IDs are rejected for every backend.
        for input in ["cuda:-1", "directml:-5", "tensorrt:-1"] {
            assert_rejected(input, "negative device ID");
        }
    }

    #[test]
    fn test_device_selection_from_str_cpu_with_id_rejected() {
        // "cpu:1" must be rejected -- cpu does not accept a device ID.
        assert_rejected("cpu:1", "cpu does not accept a device ID");
    }

    #[test]
    fn test_device_selection_from_str_cpu_zero_ok() {
        // "cpu:0" is accepted (equivalent to bare "cpu").
        assert_parses("cpu:0", DeviceKind::Cpu, 0);
    }

    #[test]
    fn test_device_selection_from_str_coreml_with_id_rejected() {
        // "coreml:1" must be rejected -- coreml does not accept a device ID.
        assert_rejected("coreml:1", "coreml does not accept a device ID");
    }

    #[test]
    fn test_device_selection_from_str_coreml_zero_ok() {
        // "coreml:0" is accepted (equivalent to bare "coreml").
        assert_parses("coreml:0", DeviceKind::CoreML, 0);
    }

    #[test]
    fn test_device_selection_display_roundtrip() {
        // Display then parse back should produce the same value.
        let cases = [
            DeviceSelection::cpu(),
            DeviceSelection::cuda(0),
            DeviceSelection::cuda(3),
            DeviceSelection::coreml(),
            DeviceSelection::directml(0),
            DeviceSelection::directml(2),
            DeviceSelection {
                kind: DeviceKind::TensorRT,
                device_id: 0,
            },
            DeviceSelection {
                kind: DeviceKind::TensorRT,
                device_id: 4,
            },
        ];
        for sel in &cases {
            let displayed = sel.to_string();
            let parsed = parse_ok(&displayed);
            assert_eq!(
                parsed.kind, sel.kind,
                "roundtrip kind failed for '{displayed}'"
            );
            assert_eq!(
                parsed.device_id, sel.device_id,
                "roundtrip id failed for '{displayed}'"
            );
        }
    }

    #[test]
    fn test_device_info_memory_display_large() {
        // 80 GB VRAM (A100-class) -- verify no overflow in display formatting
        let memory: u64 = 80 * 1024 * 1024 * 1024;
        let info = DeviceInfo {
            kind: DeviceKind::Cuda,
            device_id: 0,
            name: "NVIDIA A100".to_string(),
            available: true,
            memory_bytes: Some(memory),
        };
        let s = info.to_string();
        assert!(s.contains("80GB"), "expected '80GB' in: {s}");
        assert!(s.contains("[available]"));
        assert!(s.contains("cuda:0"));
    }
}