piper_plus/gpu.rs
1//! Low-level GPU inference support via ONNX Runtime ExecutionProviders.
2//!
3//! This module handles the **ort integration layer** -- it configures ONNX
4//! Runtime `SessionBuilder` instances with the appropriate `ExecutionProvider`
5//! (CUDA, CoreML, DirectML, TensorRT) and manages device string parsing for
6//! the engine.
7//!
8//! Feature-gated: `cuda`, `coreml`, `directml`, `tensorrt` features enable
9//! respective providers.  Auto-detection tries available providers and falls
10//! back to CPU.
11//!
12//! For the high-level, user-facing device enumeration and selection API, see
13//! [`crate::device`].
14
15use crate::error::PiperError;
16
/// Supported GPU device types.
///
/// `Eq` is derived alongside `PartialEq` (all fields are `i32`, which is
/// `Eq`), so the type can participate in full-equivalence contexts.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DeviceType {
    /// Plain CPU execution; always available.
    Cpu,
    /// NVIDIA CUDA, addressed by a zero-based device index.
    Cuda { device_id: i32 },
    /// Apple CoreML; no per-device index is exposed.
    CoreML,
    /// DirectML (Windows), addressed by a zero-based adapter index.
    DirectML { device_id: i32 },
    /// NVIDIA TensorRT, addressed by a zero-based device index.
    TensorRT { device_id: i32 },
}
26
27impl std::fmt::Display for DeviceType {
28    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29        match self {
30            DeviceType::Cpu => write!(f, "cpu"),
31            DeviceType::Cuda { device_id } => write!(f, "cuda:{device_id}"),
32            DeviceType::CoreML => write!(f, "coreml"),
33            DeviceType::DirectML { device_id } => write!(f, "directml:{device_id}"),
34            DeviceType::TensorRT { device_id } => write!(f, "tensorrt:{device_id}"),
35        }
36    }
37}
38
/// Information about an available compute device.
#[derive(Debug, Clone)]
pub struct DeviceInfo {
    // Human-readable provider name, e.g. "CPU", "CUDA", "CoreML".
    pub name: String,
    // Which provider/device this entry describes.
    pub device_type: DeviceType,
    // Whether the provider reported itself usable when probed (CPU is
    // always reported available).
    pub available: bool,
}
46
47/// Parse a device string ("cpu", "cuda", "cuda:0", "cuda:1", "coreml",
48/// "directml", "directml:2", "tensorrt", "tensorrt:0", "auto") into a
49/// [`DeviceType`].
50///
51/// The string is matched case-insensitively.
52pub fn parse_device_string(device: &str) -> Result<DeviceType, PiperError> {
53    let device = device.trim();
54
55    if device.eq_ignore_ascii_case("cpu") {
56        return Ok(DeviceType::Cpu);
57    }
58
59    if device.eq_ignore_ascii_case("auto") {
60        return Ok(auto_detect_device());
61    }
62
63    if device.eq_ignore_ascii_case("coreml") {
64        return Ok(DeviceType::CoreML);
65    }
66
67    // Handle devices with optional ":N" suffix by splitting on ':'
68    if let Some((prefix, id_str)) = device.split_once(':') {
69        // Map prefix to canonical display name for error messages
70        let canonical = if prefix.eq_ignore_ascii_case("cuda") {
71            Some("CUDA")
72        } else if prefix.eq_ignore_ascii_case("directml") {
73            Some("DirectML")
74        } else if prefix.eq_ignore_ascii_case("tensorrt") {
75            Some("TensorRT")
76        } else {
77            None
78        };
79
80        if let Some(kind_name) = canonical {
81            let device_id = id_str
82                .parse::<i32>()
83                .map_err(|_| PiperError::InvalidConfig {
84                    reason: format!("invalid {kind_name} device id: '{id_str}'"),
85                })?;
86
87            if device_id < 0 {
88                return Err(PiperError::InvalidConfig {
89                    reason: format!("negative device ID not allowed: {device_id}"),
90                });
91            }
92
93            return match kind_name {
94                "CUDA" => Ok(DeviceType::Cuda { device_id }),
95                "DirectML" => Ok(DeviceType::DirectML { device_id }),
96                "TensorRT" => Ok(DeviceType::TensorRT { device_id }),
97                _ => unreachable!(),
98            };
99        }
100    } else {
101        // No colon -- bare name defaults to device_id 0
102        if device.eq_ignore_ascii_case("cuda") {
103            return Ok(DeviceType::Cuda { device_id: 0 });
104        }
105        if device.eq_ignore_ascii_case("directml") {
106            return Ok(DeviceType::DirectML { device_id: 0 });
107        }
108        if device.eq_ignore_ascii_case("tensorrt") {
109            return Ok(DeviceType::TensorRT { device_id: 0 });
110        }
111    }
112
113    Err(PiperError::InvalidConfig {
114        reason: format!("unknown device: '{device}'"),
115    })
116}
117
/// Auto-detect the best available device.
///
/// Priority: CUDA -> CoreML -> DirectML -> CPU.
/// Only checks providers whose corresponding feature is enabled.
///
/// NOTE(review): TensorRT is absent from the priority list even when the
/// `tensorrt` feature is enabled — presumably intentional, but confirm.
fn auto_detect_device() -> DeviceType {
    #[cfg(feature = "cuda")]
    {
        // A compiled-in feature does not imply a usable device/driver, so
        // each candidate is probed at runtime before being selected.
        if is_cuda_available() {
            tracing::info!("Auto-detected CUDA device");
            return DeviceType::Cuda { device_id: 0 };
        }
    }

    #[cfg(feature = "coreml")]
    {
        if is_coreml_available() {
            tracing::info!("Auto-detected CoreML device");
            return DeviceType::CoreML;
        }
    }

    #[cfg(feature = "directml")]
    {
        if is_directml_available() {
            tracing::info!("Auto-detected DirectML device");
            return DeviceType::DirectML { device_id: 0 };
        }
    }

    // Unconditional fallback: CPU always works.
    tracing::info!("No GPU providers available, using CPU");
    DeviceType::Cpu
}
150
151/// List all available compute devices.
152///
153/// Always includes CPU. Checks for CUDA/CoreML/DirectML/TensorRT availability
154/// based on enabled features.
155pub fn list_devices() -> Vec<DeviceInfo> {
156    let mut devices = Vec::new();
157
158    // CPU is always available
159    devices.push(DeviceInfo {
160        name: "CPU".to_string(),
161        device_type: DeviceType::Cpu,
162        available: true,
163    });
164
165    #[cfg(feature = "cuda")]
166    {
167        let available = is_cuda_available();
168        devices.push(DeviceInfo {
169            name: "CUDA".to_string(),
170            device_type: DeviceType::Cuda { device_id: 0 },
171            available,
172        });
173    }
174
175    #[cfg(feature = "coreml")]
176    {
177        let available = is_coreml_available();
178        devices.push(DeviceInfo {
179            name: "CoreML".to_string(),
180            device_type: DeviceType::CoreML,
181            available,
182        });
183    }
184
185    #[cfg(feature = "directml")]
186    {
187        let available = is_directml_available();
188        devices.push(DeviceInfo {
189            name: "DirectML".to_string(),
190            device_type: DeviceType::DirectML { device_id: 0 },
191            available,
192        });
193    }
194
195    #[cfg(feature = "tensorrt")]
196    {
197        let available = is_tensorrt_available();
198        devices.push(DeviceInfo {
199            name: "TensorRT".to_string(),
200            device_type: DeviceType::TensorRT { device_id: 0 },
201            available,
202        });
203    }
204
205    devices
206}
207
/// Configure an ONNX Runtime session builder with the appropriate ExecutionProvider.
///
/// Returns the builder and the device actually used (may fall back to CPU if the
/// requested provider is unavailable or registration fails).
///
/// Uses ort v2 API:
/// ```ignore
/// use ort::ep;
/// let builder = Session::builder()?
///     .with_execution_providers([ep::CUDA::default().build()])?;
/// ```
pub fn configure_session_builder(
    builder: ort::session::builder::SessionBuilder,
    device: &DeviceType,
) -> Result<(ort::session::builder::SessionBuilder, DeviceType), PiperError> {
    // Each GPU variant has a pair of arms: one compiled when its cargo
    // feature is enabled (delegating to the provider helper below) and one
    // compiled when it is not (warn and degrade to CPU). Exactly one arm of
    // each pair exists in any build, so the match stays exhaustive.
    match device {
        DeviceType::Cpu => Ok((builder, DeviceType::Cpu)),

        #[cfg(feature = "cuda")]
        DeviceType::Cuda { device_id } => configure_cuda(builder, *device_id),
        #[cfg(not(feature = "cuda"))]
        DeviceType::Cuda { .. } => {
            tracing::warn!("CUDA requested but 'cuda' feature is not enabled, falling back to CPU");
            Ok((builder, DeviceType::Cpu))
        }

        #[cfg(feature = "coreml")]
        DeviceType::CoreML => configure_coreml(builder),
        #[cfg(not(feature = "coreml"))]
        DeviceType::CoreML => {
            tracing::warn!(
                "CoreML requested but 'coreml' feature is not enabled, falling back to CPU"
            );
            Ok((builder, DeviceType::Cpu))
        }

        #[cfg(feature = "directml")]
        DeviceType::DirectML { device_id } => configure_directml(builder, *device_id),
        #[cfg(not(feature = "directml"))]
        DeviceType::DirectML { .. } => {
            tracing::warn!(
                "DirectML requested but 'directml' feature is not enabled, falling back to CPU"
            );
            Ok((builder, DeviceType::Cpu))
        }

        #[cfg(feature = "tensorrt")]
        DeviceType::TensorRT { device_id } => configure_tensorrt(builder, *device_id),
        #[cfg(not(feature = "tensorrt"))]
        DeviceType::TensorRT { .. } => {
            tracing::warn!(
                "TensorRT requested but 'tensorrt' feature is not enabled, falling back to CPU"
            );
            Ok((builder, DeviceType::Cpu))
        }
    }
}
265
266// ---------------------------------------------------------------------------
267// Feature-gated provider helpers
268// ---------------------------------------------------------------------------
269
/// Runtime probe: does ORT report the CUDA execution provider as usable?
/// Probe errors are treated the same as "not available".
#[cfg(feature = "cuda")]
fn is_cuda_available() -> bool {
    use ort::ep::{CUDA, ExecutionProvider};
    matches!(CUDA::default().is_available(), Ok(true))
}
275
/// Register the CUDA EP on `builder` for `device_id`.
///
/// Registration failure is non-fatal: the builder is recovered from the
/// error and the session proceeds on CPU.
#[cfg(feature = "cuda")]
fn configure_cuda(
    builder: ort::session::builder::SessionBuilder,
    device_id: i32,
) -> Result<(ort::session::builder::SessionBuilder, DeviceType), PiperError> {
    let provider = ort::ep::CUDA::default().with_device_id(device_id).build();
    builder
        .with_execution_providers([provider])
        .map(|b| {
            tracing::info!("CUDA execution provider registered (device_id={device_id})");
            (b, DeviceType::Cuda { device_id })
        })
        .or_else(|e| {
            tracing::warn!("Failed to register CUDA EP: {e}, falling back to CPU");
            // The ort error carries the builder; recover it for CPU use.
            Ok((e.recover(), DeviceType::Cpu))
        })
}
294
/// Runtime probe: does ORT report the CoreML execution provider as usable?
/// Probe errors are treated the same as "not available".
#[cfg(feature = "coreml")]
fn is_coreml_available() -> bool {
    use ort::ep::{CoreML, ExecutionProvider};
    matches!(CoreML::default().is_available(), Ok(true))
}
300
/// Register the CoreML EP on `builder`.
///
/// Registration failure is non-fatal: the builder is recovered from the
/// error and the session proceeds on CPU.
#[cfg(feature = "coreml")]
fn configure_coreml(
    builder: ort::session::builder::SessionBuilder,
) -> Result<(ort::session::builder::SessionBuilder, DeviceType), PiperError> {
    let provider = ort::ep::CoreML::default().build();
    builder
        .with_execution_providers([provider])
        .map(|b| {
            tracing::info!("CoreML execution provider registered");
            (b, DeviceType::CoreML)
        })
        .or_else(|e| {
            tracing::warn!("Failed to register CoreML EP: {e}, falling back to CPU");
            // The ort error carries the builder; recover it for CPU use.
            Ok((e.recover(), DeviceType::Cpu))
        })
}
318
/// Runtime probe: does ORT report the DirectML execution provider as usable?
/// Probe errors are treated the same as "not available".
#[cfg(feature = "directml")]
fn is_directml_available() -> bool {
    use ort::ep::{DirectML, ExecutionProvider};
    matches!(DirectML::default().is_available(), Ok(true))
}
324
/// Register the DirectML EP on `builder` for `device_id`.
///
/// Registration failure is non-fatal: the builder is recovered from the
/// error and the session proceeds on CPU.
#[cfg(feature = "directml")]
fn configure_directml(
    builder: ort::session::builder::SessionBuilder,
    device_id: i32,
) -> Result<(ort::session::builder::SessionBuilder, DeviceType), PiperError> {
    let provider = ort::ep::DirectML::default()
        .with_device_id(device_id)
        .build();
    builder
        .with_execution_providers([provider])
        .map(|b| {
            tracing::info!("DirectML execution provider registered (device_id={device_id})");
            (b, DeviceType::DirectML { device_id })
        })
        .or_else(|e| {
            tracing::warn!("Failed to register DirectML EP: {e}, falling back to CPU");
            // The ort error carries the builder; recover it for CPU use.
            Ok((e.recover(), DeviceType::Cpu))
        })
}
345
/// Runtime probe: does ORT report the TensorRT execution provider as usable?
/// Probe errors are treated the same as "not available".
#[cfg(feature = "tensorrt")]
fn is_tensorrt_available() -> bool {
    use ort::ep::{ExecutionProvider, TensorRT};
    matches!(TensorRT::default().is_available(), Ok(true))
}
351
/// Register the TensorRT EP on `builder` for `device_id`.
///
/// Registration failure is non-fatal: the builder is recovered from the
/// error and the session proceeds on CPU.
#[cfg(feature = "tensorrt")]
fn configure_tensorrt(
    builder: ort::session::builder::SessionBuilder,
    device_id: i32,
) -> Result<(ort::session::builder::SessionBuilder, DeviceType), PiperError> {
    let provider = ort::ep::TensorRT::default()
        .with_device_id(device_id)
        .build();
    builder
        .with_execution_providers([provider])
        .map(|b| {
            tracing::info!("TensorRT execution provider registered (device_id={device_id})");
            (b, DeviceType::TensorRT { device_id })
        })
        .or_else(|e| {
            tracing::warn!("Failed to register TensorRT EP: {e}, falling back to CPU");
            // The ort error carries the builder; recover it for CPU use.
            Ok((e.recover(), DeviceType::Cpu))
        })
}
372
// Unit tests. These pin exact behavior — error-message substrings, default
// device ids, Display formatting, and the Display/parse round-trip — so the
// bodies are kept verbatim.
#[cfg(test)]
mod tests {
    use super::*;

    // -----------------------------------------------------------------------
    // parse_device_string tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_parse_cpu() {
        let dt = parse_device_string("cpu").unwrap();
        assert_eq!(dt, DeviceType::Cpu);
    }

    #[test]
    fn test_parse_cpu_uppercase() {
        // Case-insensitive matching is part of the documented contract.
        let dt = parse_device_string("CPU").unwrap();
        assert_eq!(dt, DeviceType::Cpu);
    }

    #[test]
    fn test_parse_cuda_default() {
        // Bare "cuda" (no ":N" suffix) must default to device 0.
        let dt = parse_device_string("cuda").unwrap();
        assert_eq!(dt, DeviceType::Cuda { device_id: 0 });
    }

    #[test]
    fn test_parse_cuda_device_0() {
        let dt = parse_device_string("cuda:0").unwrap();
        assert_eq!(dt, DeviceType::Cuda { device_id: 0 });
    }

    #[test]
    fn test_parse_cuda_device_1() {
        let dt = parse_device_string("cuda:1").unwrap();
        assert_eq!(dt, DeviceType::Cuda { device_id: 1 });
    }

    #[test]
    fn test_parse_cuda_mixed_case() {
        let dt = parse_device_string("CUDA:2").unwrap();
        assert_eq!(dt, DeviceType::Cuda { device_id: 2 });
    }

    #[test]
    fn test_parse_coreml() {
        let dt = parse_device_string("coreml").unwrap();
        assert_eq!(dt, DeviceType::CoreML);
    }

    #[test]
    fn test_parse_coreml_uppercase() {
        let dt = parse_device_string("CoreML").unwrap();
        assert_eq!(dt, DeviceType::CoreML);
    }

    #[test]
    fn test_parse_directml_default() {
        let dt = parse_device_string("directml").unwrap();
        assert_eq!(dt, DeviceType::DirectML { device_id: 0 });
    }

    #[test]
    fn test_parse_directml_device_2() {
        let dt = parse_device_string("directml:2").unwrap();
        assert_eq!(dt, DeviceType::DirectML { device_id: 2 });
    }

    #[test]
    fn test_parse_tensorrt_default() {
        let dt = parse_device_string("tensorrt").unwrap();
        assert_eq!(dt, DeviceType::TensorRT { device_id: 0 });
    }

    #[test]
    fn test_parse_tensorrt_device_0() {
        let dt = parse_device_string("tensorrt:0").unwrap();
        assert_eq!(dt, DeviceType::TensorRT { device_id: 0 });
    }

    #[test]
    fn test_parse_auto() {
        // Auto should always succeed (falls back to CPU when no GPU features enabled)
        let dt = parse_device_string("auto").unwrap();
        // Without GPU features, auto resolves to CPU
        #[cfg(not(any(feature = "cuda", feature = "coreml", feature = "directml")))]
        assert_eq!(dt, DeviceType::Cpu);
        // With any GPU feature, it may resolve to a GPU device (still valid)
        #[cfg(any(feature = "cuda", feature = "coreml", feature = "directml"))]
        let _ = dt; // just ensure no error
    }

    // -----------------------------------------------------------------------
    // parse_device_string error cases
    // -----------------------------------------------------------------------

    #[test]
    fn test_parse_invalid_device() {
        let result = parse_device_string("vulkan");
        assert!(result.is_err());
        let err_msg = format!("{}", result.unwrap_err());
        assert!(err_msg.contains("unknown device"));
    }

    #[test]
    fn test_parse_cuda_invalid_id() {
        // The error message must name the provider in canonical casing.
        let result = parse_device_string("cuda:abc");
        assert!(result.is_err());
        let err_msg = format!("{}", result.unwrap_err());
        assert!(err_msg.contains("invalid CUDA device id"));
    }

    #[test]
    fn test_parse_directml_invalid_id() {
        let result = parse_device_string("directml:xyz");
        assert!(result.is_err());
        let err_msg = format!("{}", result.unwrap_err());
        assert!(err_msg.contains("invalid DirectML device id"));
    }

    #[test]
    fn test_parse_tensorrt_invalid_id() {
        let result = parse_device_string("tensorrt:bad");
        assert!(result.is_err());
        let err_msg = format!("{}", result.unwrap_err());
        assert!(err_msg.contains("invalid TensorRT device id"));
    }

    #[test]
    fn test_parse_empty_string() {
        let result = parse_device_string("");
        assert!(result.is_err());
    }

    // -----------------------------------------------------------------------
    // list_devices tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_list_devices_contains_cpu() {
        let devices = list_devices();
        assert!(!devices.is_empty());
        assert!(devices.iter().any(|d| d.device_type == DeviceType::Cpu));
    }

    #[test]
    fn test_list_devices_cpu_always_available() {
        let devices = list_devices();
        let cpu = devices
            .iter()
            .find(|d| d.device_type == DeviceType::Cpu)
            .unwrap();
        assert!(cpu.available);
        assert_eq!(cpu.name, "CPU");
    }

    #[test]
    fn test_list_devices_first_is_cpu() {
        // CPU is pushed first in list_devices; callers may rely on ordering.
        let devices = list_devices();
        assert_eq!(devices[0].device_type, DeviceType::Cpu);
    }

    // -----------------------------------------------------------------------
    // DeviceType Display tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_display_cpu() {
        assert_eq!(format!("{}", DeviceType::Cpu), "cpu");
    }

    #[test]
    fn test_display_cuda() {
        assert_eq!(format!("{}", DeviceType::Cuda { device_id: 0 }), "cuda:0");
        assert_eq!(format!("{}", DeviceType::Cuda { device_id: 3 }), "cuda:3");
    }

    #[test]
    fn test_display_coreml() {
        assert_eq!(format!("{}", DeviceType::CoreML), "coreml");
    }

    #[test]
    fn test_display_directml() {
        assert_eq!(
            format!("{}", DeviceType::DirectML { device_id: 1 }),
            "directml:1"
        );
    }

    #[test]
    fn test_display_tensorrt() {
        assert_eq!(
            format!("{}", DeviceType::TensorRT { device_id: 0 }),
            "tensorrt:0"
        );
    }

    // -----------------------------------------------------------------------
    // DeviceInfo tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_device_info_construction() {
        let info = DeviceInfo {
            name: "TestGPU".to_string(),
            device_type: DeviceType::Cuda { device_id: 1 },
            available: true,
        };
        assert_eq!(info.name, "TestGPU");
        assert_eq!(info.device_type, DeviceType::Cuda { device_id: 1 });
        assert!(info.available);
    }

    #[test]
    fn test_device_info_debug() {
        let info = DeviceInfo {
            name: "CPU".to_string(),
            device_type: DeviceType::Cpu,
            available: true,
        };
        let debug = format!("{:?}", info);
        assert!(debug.contains("CPU"));
        assert!(debug.contains("available: true"));
    }

    #[test]
    fn test_device_info_clone() {
        let info = DeviceInfo {
            name: "CUDA".to_string(),
            device_type: DeviceType::Cuda { device_id: 0 },
            available: false,
        };
        let cloned = info.clone();
        assert_eq!(cloned.name, info.name);
        assert_eq!(cloned.device_type, info.device_type);
        assert_eq!(cloned.available, info.available);
    }

    // -----------------------------------------------------------------------
    // DeviceType equality and clone tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_device_type_equality() {
        assert_eq!(DeviceType::Cpu, DeviceType::Cpu);
        assert_eq!(
            DeviceType::Cuda { device_id: 0 },
            DeviceType::Cuda { device_id: 0 }
        );
        assert_ne!(
            DeviceType::Cuda { device_id: 0 },
            DeviceType::Cuda { device_id: 1 }
        );
        assert_ne!(DeviceType::Cpu, DeviceType::CoreML);
    }

    #[test]
    fn test_device_type_clone() {
        let dt = DeviceType::TensorRT { device_id: 2 };
        let cloned = dt.clone();
        assert_eq!(dt, cloned);
    }

    // -----------------------------------------------------------------------
    // Feature-gated availability tests
    // -----------------------------------------------------------------------

    #[cfg(feature = "cuda")]
    #[test]
    fn test_cuda_listed_when_feature_enabled() {
        let devices = list_devices();
        assert!(
            devices
                .iter()
                .any(|d| matches!(d.device_type, DeviceType::Cuda { .. }))
        );
    }

    #[cfg(feature = "coreml")]
    #[test]
    fn test_coreml_listed_when_feature_enabled() {
        let devices = list_devices();
        assert!(devices.iter().any(|d| d.device_type == DeviceType::CoreML));
    }

    #[cfg(feature = "directml")]
    #[test]
    fn test_directml_listed_when_feature_enabled() {
        let devices = list_devices();
        assert!(
            devices
                .iter()
                .any(|d| matches!(d.device_type, DeviceType::DirectML { .. }))
        );
    }

    #[cfg(feature = "tensorrt")]
    #[test]
    fn test_tensorrt_listed_when_feature_enabled() {
        let devices = list_devices();
        assert!(
            devices
                .iter()
                .any(|d| matches!(d.device_type, DeviceType::TensorRT { .. }))
        );
    }

    // -----------------------------------------------------------------------
    // configure_session_builder CPU test
    // -----------------------------------------------------------------------

    #[test]
    fn test_configure_cpu_returns_cpu() {
        // CPU configuration should always succeed without needing an actual model
        let builder = ort::session::Session::builder().expect("session builder");
        let (_, actual_device) = configure_session_builder(builder, &DeviceType::Cpu).unwrap();
        assert_eq!(actual_device, DeviceType::Cpu);
    }

    // -----------------------------------------------------------------------
    // Fallback tests (feature not enabled)
    // -----------------------------------------------------------------------

    #[cfg(not(feature = "cuda"))]
    #[test]
    fn test_cuda_fallback_without_feature() {
        let builder = ort::session::Session::builder().expect("session builder");
        let (_, actual_device) =
            configure_session_builder(builder, &DeviceType::Cuda { device_id: 0 }).unwrap();
        assert_eq!(actual_device, DeviceType::Cpu);
    }

    #[cfg(not(feature = "coreml"))]
    #[test]
    fn test_coreml_fallback_without_feature() {
        let builder = ort::session::Session::builder().expect("session builder");
        let (_, actual_device) = configure_session_builder(builder, &DeviceType::CoreML).unwrap();
        assert_eq!(actual_device, DeviceType::Cpu);
    }

    #[cfg(not(feature = "directml"))]
    #[test]
    fn test_directml_fallback_without_feature() {
        let builder = ort::session::Session::builder().expect("session builder");
        let (_, actual_device) =
            configure_session_builder(builder, &DeviceType::DirectML { device_id: 0 }).unwrap();
        assert_eq!(actual_device, DeviceType::Cpu);
    }

    #[cfg(not(feature = "tensorrt"))]
    #[test]
    fn test_tensorrt_fallback_without_feature() {
        let builder = ort::session::Session::builder().expect("session builder");
        let (_, actual_device) =
            configure_session_builder(builder, &DeviceType::TensorRT { device_id: 0 }).unwrap();
        assert_eq!(actual_device, DeviceType::Cpu);
    }

    // -----------------------------------------------------------------------
    // Additional TDD tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_auto_detect_device_returns_valid() {
        let dt = parse_device_string("auto").unwrap();
        // Regardless of features, the result must be a valid DeviceType variant
        match dt {
            DeviceType::Cpu
            | DeviceType::Cuda { .. }
            | DeviceType::CoreML
            | DeviceType::DirectML { .. }
            | DeviceType::TensorRT { .. } => {} // all valid
        }
    }

    #[test]
    fn test_parse_device_string_whitespace() {
        // parse_device_string trims whitespace before matching.
        let dt = parse_device_string(" cuda ").unwrap();
        assert_eq!(dt, DeviceType::Cuda { device_id: 0 });
    }

    #[test]
    fn test_parse_device_string_large_device_id() {
        // No upper bound on ids is enforced; only negatives are rejected.
        let dt = parse_device_string("cuda:999").unwrap();
        assert_eq!(dt, DeviceType::Cuda { device_id: 999 });
    }

    #[test]
    fn test_device_type_default_display_roundtrip() {
        // For each variant, Display then parse back should produce the same value
        let variants = vec![
            DeviceType::Cpu,
            DeviceType::Cuda { device_id: 0 },
            DeviceType::Cuda { device_id: 7 },
            DeviceType::CoreML,
            DeviceType::DirectML { device_id: 0 },
            DeviceType::DirectML { device_id: 3 },
            DeviceType::TensorRT { device_id: 0 },
            DeviceType::TensorRT { device_id: 5 },
        ];
        for variant in variants {
            let displayed = format!("{variant}");
            let parsed = parse_device_string(&displayed).unwrap();
            assert_eq!(parsed, variant, "roundtrip failed for '{displayed}'");
        }
    }
}