1use crate::error::PiperError;
16
/// Execution backend an inference session can be configured to run on.
///
/// Variants that carry a `device_id` hold the index handed to the provider's
/// `with_device_id` when the session builder is configured.
/// [`parse_device_string`] only ever produces non-negative ids.
///
/// The `Display` form (`cpu`, `cuda:N`, `coreml`, `directml:N`, `tensorrt:N`)
/// is accepted back by [`parse_device_string`].
///
/// `Eq` and `Hash` are derived so a device can be used as a map or set key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum DeviceType {
    /// Plain CPU execution; no execution provider is registered.
    Cpu,
    /// NVIDIA CUDA execution provider on the given device index.
    Cuda { device_id: i32 },
    /// Apple CoreML execution provider (takes no device index).
    CoreML,
    /// DirectML execution provider on the given device index.
    DirectML { device_id: i32 },
    /// NVIDIA TensorRT execution provider on the given device index.
    TensorRT { device_id: i32 },
}
26
27impl std::fmt::Display for DeviceType {
28 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29 match self {
30 DeviceType::Cpu => write!(f, "cpu"),
31 DeviceType::Cuda { device_id } => write!(f, "cuda:{device_id}"),
32 DeviceType::CoreML => write!(f, "coreml"),
33 DeviceType::DirectML { device_id } => write!(f, "directml:{device_id}"),
34 DeviceType::TensorRT { device_id } => write!(f, "tensorrt:{device_id}"),
35 }
36 }
37}
38
39#[derive(Debug, Clone)]
41pub struct DeviceInfo {
42 pub name: String,
43 pub device_type: DeviceType,
44 pub available: bool,
45}
46
47pub fn parse_device_string(device: &str) -> Result<DeviceType, PiperError> {
53 let device = device.trim();
54
55 if device.eq_ignore_ascii_case("cpu") {
56 return Ok(DeviceType::Cpu);
57 }
58
59 if device.eq_ignore_ascii_case("auto") {
60 return Ok(auto_detect_device());
61 }
62
63 if device.eq_ignore_ascii_case("coreml") {
64 return Ok(DeviceType::CoreML);
65 }
66
67 if let Some((prefix, id_str)) = device.split_once(':') {
69 let canonical = if prefix.eq_ignore_ascii_case("cuda") {
71 Some("CUDA")
72 } else if prefix.eq_ignore_ascii_case("directml") {
73 Some("DirectML")
74 } else if prefix.eq_ignore_ascii_case("tensorrt") {
75 Some("TensorRT")
76 } else {
77 None
78 };
79
80 if let Some(kind_name) = canonical {
81 let device_id = id_str
82 .parse::<i32>()
83 .map_err(|_| PiperError::InvalidConfig {
84 reason: format!("invalid {kind_name} device id: '{id_str}'"),
85 })?;
86
87 if device_id < 0 {
88 return Err(PiperError::InvalidConfig {
89 reason: format!("negative device ID not allowed: {device_id}"),
90 });
91 }
92
93 return match kind_name {
94 "CUDA" => Ok(DeviceType::Cuda { device_id }),
95 "DirectML" => Ok(DeviceType::DirectML { device_id }),
96 "TensorRT" => Ok(DeviceType::TensorRT { device_id }),
97 _ => unreachable!(),
98 };
99 }
100 } else {
101 if device.eq_ignore_ascii_case("cuda") {
103 return Ok(DeviceType::Cuda { device_id: 0 });
104 }
105 if device.eq_ignore_ascii_case("directml") {
106 return Ok(DeviceType::DirectML { device_id: 0 });
107 }
108 if device.eq_ignore_ascii_case("tensorrt") {
109 return Ok(DeviceType::TensorRT { device_id: 0 });
110 }
111 }
112
113 Err(PiperError::InvalidConfig {
114 reason: format!("unknown device: '{device}'"),
115 })
116}
117
118fn auto_detect_device() -> DeviceType {
123 #[cfg(feature = "cuda")]
124 {
125 if is_cuda_available() {
126 tracing::info!("Auto-detected CUDA device");
127 return DeviceType::Cuda { device_id: 0 };
128 }
129 }
130
131 #[cfg(feature = "coreml")]
132 {
133 if is_coreml_available() {
134 tracing::info!("Auto-detected CoreML device");
135 return DeviceType::CoreML;
136 }
137 }
138
139 #[cfg(feature = "directml")]
140 {
141 if is_directml_available() {
142 tracing::info!("Auto-detected DirectML device");
143 return DeviceType::DirectML { device_id: 0 };
144 }
145 }
146
147 tracing::info!("No GPU providers available, using CPU");
148 DeviceType::Cpu
149}
150
151pub fn list_devices() -> Vec<DeviceInfo> {
156 let mut devices = Vec::new();
157
158 devices.push(DeviceInfo {
160 name: "CPU".to_string(),
161 device_type: DeviceType::Cpu,
162 available: true,
163 });
164
165 #[cfg(feature = "cuda")]
166 {
167 let available = is_cuda_available();
168 devices.push(DeviceInfo {
169 name: "CUDA".to_string(),
170 device_type: DeviceType::Cuda { device_id: 0 },
171 available,
172 });
173 }
174
175 #[cfg(feature = "coreml")]
176 {
177 let available = is_coreml_available();
178 devices.push(DeviceInfo {
179 name: "CoreML".to_string(),
180 device_type: DeviceType::CoreML,
181 available,
182 });
183 }
184
185 #[cfg(feature = "directml")]
186 {
187 let available = is_directml_available();
188 devices.push(DeviceInfo {
189 name: "DirectML".to_string(),
190 device_type: DeviceType::DirectML { device_id: 0 },
191 available,
192 });
193 }
194
195 #[cfg(feature = "tensorrt")]
196 {
197 let available = is_tensorrt_available();
198 devices.push(DeviceInfo {
199 name: "TensorRT".to_string(),
200 device_type: DeviceType::TensorRT { device_id: 0 },
201 available,
202 });
203 }
204
205 devices
206}
207
208pub fn configure_session_builder(
220 builder: ort::session::builder::SessionBuilder,
221 device: &DeviceType,
222) -> Result<(ort::session::builder::SessionBuilder, DeviceType), PiperError> {
223 match device {
224 DeviceType::Cpu => Ok((builder, DeviceType::Cpu)),
225
226 #[cfg(feature = "cuda")]
227 DeviceType::Cuda { device_id } => configure_cuda(builder, *device_id),
228 #[cfg(not(feature = "cuda"))]
229 DeviceType::Cuda { .. } => {
230 tracing::warn!("CUDA requested but 'cuda' feature is not enabled, falling back to CPU");
231 Ok((builder, DeviceType::Cpu))
232 }
233
234 #[cfg(feature = "coreml")]
235 DeviceType::CoreML => configure_coreml(builder),
236 #[cfg(not(feature = "coreml"))]
237 DeviceType::CoreML => {
238 tracing::warn!(
239 "CoreML requested but 'coreml' feature is not enabled, falling back to CPU"
240 );
241 Ok((builder, DeviceType::Cpu))
242 }
243
244 #[cfg(feature = "directml")]
245 DeviceType::DirectML { device_id } => configure_directml(builder, *device_id),
246 #[cfg(not(feature = "directml"))]
247 DeviceType::DirectML { .. } => {
248 tracing::warn!(
249 "DirectML requested but 'directml' feature is not enabled, falling back to CPU"
250 );
251 Ok((builder, DeviceType::Cpu))
252 }
253
254 #[cfg(feature = "tensorrt")]
255 DeviceType::TensorRT { device_id } => configure_tensorrt(builder, *device_id),
256 #[cfg(not(feature = "tensorrt"))]
257 DeviceType::TensorRT { .. } => {
258 tracing::warn!(
259 "TensorRT requested but 'tensorrt' feature is not enabled, falling back to CPU"
260 );
261 Ok((builder, DeviceType::Cpu))
262 }
263 }
264}
265
/// Reports whether the CUDA execution provider says it is available.
///
/// A probe that returns an error is treated the same as "not available".
#[cfg(feature = "cuda")]
fn is_cuda_available() -> bool {
    use ort::ep::ExecutionProvider;

    let provider = ort::ep::CUDA::default();
    matches!(provider.is_available(), Ok(true))
}
275
/// Registers the CUDA execution provider for `device_id` on `builder`.
///
/// Returns the builder and [`DeviceType::Cuda`] when registration succeeds.
/// When registration fails, a warning is logged, the builder is recovered from
/// the error and returned together with [`DeviceType::Cpu`].
///
/// # Errors
///
/// Never returns `Err` in the current implementation; registration failures
/// become a CPU fallback.
#[cfg(feature = "cuda")]
fn configure_cuda(
    builder: ort::session::builder::SessionBuilder,
    device_id: i32,
) -> Result<(ort::session::builder::SessionBuilder, DeviceType), PiperError> {
    // ort's default is to log a failed registration and carry on, which would
    // make this function report CUDA while the session runs on CPU. Opting in
    // to `error_on_failure` is what lets the `Err` arm below actually run.
    let ep = ort::ep::CUDA::default()
        .with_device_id(device_id)
        .build()
        .error_on_failure();
    match builder.with_execution_providers([ep]) {
        Ok(b) => {
            tracing::info!("CUDA execution provider registered (device_id={device_id})");
            Ok((b, DeviceType::Cuda { device_id }))
        }
        Err(e) => {
            tracing::warn!("Failed to register CUDA EP: {e}, falling back to CPU");
            // The error hands the builder back so the session can still be built.
            let recovered = e.recover();
            Ok((recovered, DeviceType::Cpu))
        }
    }
}
294
/// Reports whether the CoreML execution provider says it is available.
///
/// A probe that returns an error is treated the same as "not available".
#[cfg(feature = "coreml")]
fn is_coreml_available() -> bool {
    use ort::ep::ExecutionProvider;

    let provider = ort::ep::CoreML::default();
    matches!(provider.is_available(), Ok(true))
}
300
/// Registers the CoreML execution provider on `builder`.
///
/// Returns the builder and [`DeviceType::CoreML`] when registration succeeds.
/// When registration fails, a warning is logged, the builder is recovered from
/// the error and returned together with [`DeviceType::Cpu`].
///
/// # Errors
///
/// Never returns `Err` in the current implementation; registration failures
/// become a CPU fallback.
#[cfg(feature = "coreml")]
fn configure_coreml(
    builder: ort::session::builder::SessionBuilder,
) -> Result<(ort::session::builder::SessionBuilder, DeviceType), PiperError> {
    // ort's default is to log a failed registration and carry on, which would
    // make this function report CoreML while the session runs on CPU. Opting
    // in to `error_on_failure` is what lets the `Err` arm below actually run.
    let ep = ort::ep::CoreML::default().build().error_on_failure();
    match builder.with_execution_providers([ep]) {
        Ok(b) => {
            tracing::info!("CoreML execution provider registered");
            Ok((b, DeviceType::CoreML))
        }
        Err(e) => {
            tracing::warn!("Failed to register CoreML EP: {e}, falling back to CPU");
            // The error hands the builder back so the session can still be built.
            let recovered = e.recover();
            Ok((recovered, DeviceType::Cpu))
        }
    }
}
318
/// Reports whether the DirectML execution provider says it is available.
///
/// A probe that returns an error is treated the same as "not available".
#[cfg(feature = "directml")]
fn is_directml_available() -> bool {
    use ort::ep::ExecutionProvider;

    let provider = ort::ep::DirectML::default();
    matches!(provider.is_available(), Ok(true))
}
324
/// Registers the DirectML execution provider for `device_id` on `builder`.
///
/// Returns the builder and [`DeviceType::DirectML`] when registration
/// succeeds. When registration fails, a warning is logged, the builder is
/// recovered from the error and returned together with [`DeviceType::Cpu`].
///
/// # Errors
///
/// Never returns `Err` in the current implementation; registration failures
/// become a CPU fallback.
#[cfg(feature = "directml")]
fn configure_directml(
    builder: ort::session::builder::SessionBuilder,
    device_id: i32,
) -> Result<(ort::session::builder::SessionBuilder, DeviceType), PiperError> {
    // ort's default is to log a failed registration and carry on, which would
    // make this function report DirectML while the session runs on CPU. Opting
    // in to `error_on_failure` is what lets the `Err` arm below actually run.
    let ep = ort::ep::DirectML::default()
        .with_device_id(device_id)
        .build()
        .error_on_failure();
    match builder.with_execution_providers([ep]) {
        Ok(b) => {
            tracing::info!("DirectML execution provider registered (device_id={device_id})");
            Ok((b, DeviceType::DirectML { device_id }))
        }
        Err(e) => {
            tracing::warn!("Failed to register DirectML EP: {e}, falling back to CPU");
            // The error hands the builder back so the session can still be built.
            let recovered = e.recover();
            Ok((recovered, DeviceType::Cpu))
        }
    }
}
345
/// Reports whether the TensorRT execution provider says it is available.
///
/// A probe that returns an error is treated the same as "not available".
#[cfg(feature = "tensorrt")]
fn is_tensorrt_available() -> bool {
    use ort::ep::ExecutionProvider;

    let provider = ort::ep::TensorRT::default();
    matches!(provider.is_available(), Ok(true))
}
351
/// Registers the TensorRT execution provider for `device_id` on `builder`.
///
/// Returns the builder and [`DeviceType::TensorRT`] when registration
/// succeeds. When registration fails, a warning is logged, the builder is
/// recovered from the error and returned together with [`DeviceType::Cpu`].
///
/// # Errors
///
/// Never returns `Err` in the current implementation; registration failures
/// become a CPU fallback.
#[cfg(feature = "tensorrt")]
fn configure_tensorrt(
    builder: ort::session::builder::SessionBuilder,
    device_id: i32,
) -> Result<(ort::session::builder::SessionBuilder, DeviceType), PiperError> {
    // ort's default is to log a failed registration and carry on, which would
    // make this function report TensorRT while the session runs on CPU. Opting
    // in to `error_on_failure` is what lets the `Err` arm below actually run.
    let ep = ort::ep::TensorRT::default()
        .with_device_id(device_id)
        .build()
        .error_on_failure();
    match builder.with_execution_providers([ep]) {
        Ok(b) => {
            tracing::info!("TensorRT execution provider registered (device_id={device_id})");
            Ok((b, DeviceType::TensorRT { device_id }))
        }
        Err(e) => {
            tracing::warn!("Failed to register TensorRT EP: {e}, falling back to CPU");
            // The error hands the builder back so the session can still be built.
            let recovered = e.recover();
            Ok((recovered, DeviceType::Cpu))
        }
    }
}
372
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `input`, requires failure, and returns the rendered error text.
    fn parse_error_message(input: &str) -> String {
        parse_device_string(input)
            .expect_err("device string should have been rejected")
            .to_string()
    }

    // ----- parse_device_string: accepted spellings -----

    #[test]
    fn test_parse_cpu() {
        let dt = parse_device_string("cpu").unwrap();
        assert_eq!(dt, DeviceType::Cpu);
    }

    #[test]
    fn test_parse_cpu_uppercase() {
        let dt = parse_device_string("CPU").unwrap();
        assert_eq!(dt, DeviceType::Cpu);
    }

    #[test]
    fn test_parse_cuda_default() {
        let dt = parse_device_string("cuda").unwrap();
        assert_eq!(dt, DeviceType::Cuda { device_id: 0 });
    }

    #[test]
    fn test_parse_cuda_device_0() {
        let dt = parse_device_string("cuda:0").unwrap();
        assert_eq!(dt, DeviceType::Cuda { device_id: 0 });
    }

    #[test]
    fn test_parse_cuda_device_1() {
        let dt = parse_device_string("cuda:1").unwrap();
        assert_eq!(dt, DeviceType::Cuda { device_id: 1 });
    }

    #[test]
    fn test_parse_cuda_mixed_case() {
        let dt = parse_device_string("CUDA:2").unwrap();
        assert_eq!(dt, DeviceType::Cuda { device_id: 2 });
    }

    #[test]
    fn test_parse_coreml() {
        let dt = parse_device_string("coreml").unwrap();
        assert_eq!(dt, DeviceType::CoreML);
    }

    #[test]
    fn test_parse_coreml_uppercase() {
        let dt = parse_device_string("CoreML").unwrap();
        assert_eq!(dt, DeviceType::CoreML);
    }

    #[test]
    fn test_parse_directml_default() {
        let dt = parse_device_string("directml").unwrap();
        assert_eq!(dt, DeviceType::DirectML { device_id: 0 });
    }

    #[test]
    fn test_parse_directml_device_2() {
        let dt = parse_device_string("directml:2").unwrap();
        assert_eq!(dt, DeviceType::DirectML { device_id: 2 });
    }

    #[test]
    fn test_parse_tensorrt_default() {
        let dt = parse_device_string("tensorrt").unwrap();
        assert_eq!(dt, DeviceType::TensorRT { device_id: 0 });
    }

    #[test]
    fn test_parse_tensorrt_device_0() {
        let dt = parse_device_string("tensorrt:0").unwrap();
        assert_eq!(dt, DeviceType::TensorRT { device_id: 0 });
    }

    #[test]
    fn test_parse_auto() {
        let dt = parse_device_string("auto").unwrap();
        // Without any auto-detectable provider compiled in, only CPU is possible.
        #[cfg(not(any(feature = "cuda", feature = "coreml", feature = "directml")))]
        assert_eq!(dt, DeviceType::Cpu);
        // With providers compiled in the outcome depends on the host machine.
        #[cfg(any(feature = "cuda", feature = "coreml", feature = "directml"))]
        let _ = dt;
    }

    // ----- parse_device_string: rejected input -----

    #[test]
    fn test_parse_invalid_device() {
        let err_msg = parse_error_message("vulkan");
        assert!(err_msg.contains("unknown device"));
    }

    #[test]
    fn test_parse_cuda_invalid_id() {
        let err_msg = parse_error_message("cuda:abc");
        assert!(err_msg.contains("invalid CUDA device id"));
    }

    #[test]
    fn test_parse_directml_invalid_id() {
        let err_msg = parse_error_message("directml:xyz");
        assert!(err_msg.contains("invalid DirectML device id"));
    }

    #[test]
    fn test_parse_tensorrt_invalid_id() {
        let err_msg = parse_error_message("tensorrt:bad");
        assert!(err_msg.contains("invalid TensorRT device id"));
    }

    #[test]
    fn test_parse_empty_string() {
        let result = parse_device_string("");
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_negative_device_id() {
        // Negative ids parse as `i32` but must still be refused.
        for input in ["cuda:-1", "directml:-2", "tensorrt:-3"] {
            let err_msg = parse_error_message(input);
            assert!(
                err_msg.contains("negative device ID"),
                "unexpected error for '{input}': {err_msg}"
            );
        }
    }

    #[test]
    fn test_parse_unknown_prefix_with_id() {
        // An id suffix neither rescues an unknown provider nor is accepted on
        // providers that take no device index.
        for input in ["vulkan:0", "cpu:0", "coreml:1"] {
            let err_msg = parse_error_message(input);
            assert!(
                err_msg.contains("unknown device"),
                "unexpected error for '{input}': {err_msg}"
            );
        }
    }

    // ----- list_devices -----

    #[test]
    fn test_list_devices_contains_cpu() {
        let devices = list_devices();
        assert!(!devices.is_empty());
        assert!(devices.iter().any(|d| d.device_type == DeviceType::Cpu));
    }

    #[test]
    fn test_list_devices_cpu_always_available() {
        let devices = list_devices();
        let cpu = devices
            .iter()
            .find(|d| d.device_type == DeviceType::Cpu)
            .unwrap();
        assert!(cpu.available);
        assert_eq!(cpu.name, "CPU");
    }

    #[test]
    fn test_list_devices_first_is_cpu() {
        let devices = list_devices();
        assert_eq!(devices[0].device_type, DeviceType::Cpu);
    }

    // ----- Display -----

    #[test]
    fn test_display_cpu() {
        assert_eq!(DeviceType::Cpu.to_string(), "cpu");
    }

    #[test]
    fn test_display_cuda() {
        assert_eq!(DeviceType::Cuda { device_id: 0 }.to_string(), "cuda:0");
        assert_eq!(DeviceType::Cuda { device_id: 3 }.to_string(), "cuda:3");
    }

    #[test]
    fn test_display_coreml() {
        assert_eq!(DeviceType::CoreML.to_string(), "coreml");
    }

    #[test]
    fn test_display_directml() {
        assert_eq!(
            DeviceType::DirectML { device_id: 1 }.to_string(),
            "directml:1"
        );
    }

    #[test]
    fn test_display_tensorrt() {
        assert_eq!(
            DeviceType::TensorRT { device_id: 0 }.to_string(),
            "tensorrt:0"
        );
    }

    // ----- DeviceInfo -----

    #[test]
    fn test_device_info_construction() {
        let info = DeviceInfo {
            name: "TestGPU".to_string(),
            device_type: DeviceType::Cuda { device_id: 1 },
            available: true,
        };
        assert_eq!(info.name, "TestGPU");
        assert_eq!(info.device_type, DeviceType::Cuda { device_id: 1 });
        assert!(info.available);
    }

    #[test]
    fn test_device_info_debug() {
        let info = DeviceInfo {
            name: "CPU".to_string(),
            device_type: DeviceType::Cpu,
            available: true,
        };
        let debug = format!("{info:?}");
        assert!(debug.contains("CPU"));
        assert!(debug.contains("available: true"));
    }

    #[test]
    fn test_device_info_clone() {
        let info = DeviceInfo {
            name: "CUDA".to_string(),
            device_type: DeviceType::Cuda { device_id: 0 },
            available: false,
        };
        let cloned = info.clone();
        assert_eq!(cloned.name, info.name);
        assert_eq!(cloned.device_type, info.device_type);
        assert_eq!(cloned.available, info.available);
    }

    // ----- DeviceType traits -----

    #[test]
    fn test_device_type_equality() {
        assert_eq!(DeviceType::Cpu, DeviceType::Cpu);
        assert_eq!(
            DeviceType::Cuda { device_id: 0 },
            DeviceType::Cuda { device_id: 0 }
        );
        assert_ne!(
            DeviceType::Cuda { device_id: 0 },
            DeviceType::Cuda { device_id: 1 }
        );
        assert_ne!(DeviceType::Cpu, DeviceType::CoreML);
    }

    #[test]
    fn test_device_type_clone() {
        let dt = DeviceType::TensorRT { device_id: 2 };
        let cloned = dt.clone();
        assert_eq!(dt, cloned);
    }

    // ----- feature-gated listing -----

    #[cfg(feature = "cuda")]
    #[test]
    fn test_cuda_listed_when_feature_enabled() {
        let devices = list_devices();
        assert!(
            devices
                .iter()
                .any(|d| matches!(d.device_type, DeviceType::Cuda { .. }))
        );
    }

    #[cfg(feature = "coreml")]
    #[test]
    fn test_coreml_listed_when_feature_enabled() {
        let devices = list_devices();
        assert!(devices.iter().any(|d| d.device_type == DeviceType::CoreML));
    }

    #[cfg(feature = "directml")]
    #[test]
    fn test_directml_listed_when_feature_enabled() {
        let devices = list_devices();
        assert!(
            devices
                .iter()
                .any(|d| matches!(d.device_type, DeviceType::DirectML { .. }))
        );
    }

    #[cfg(feature = "tensorrt")]
    #[test]
    fn test_tensorrt_listed_when_feature_enabled() {
        let devices = list_devices();
        assert!(
            devices
                .iter()
                .any(|d| matches!(d.device_type, DeviceType::TensorRT { .. }))
        );
    }

    // ----- configure_session_builder -----

    #[test]
    fn test_configure_cpu_returns_cpu() {
        let builder = ort::session::Session::builder().expect("session builder");
        let (_, actual_device) = configure_session_builder(builder, &DeviceType::Cpu).unwrap();
        assert_eq!(actual_device, DeviceType::Cpu);
    }

    #[cfg(not(feature = "cuda"))]
    #[test]
    fn test_cuda_fallback_without_feature() {
        let builder = ort::session::Session::builder().expect("session builder");
        let (_, actual_device) =
            configure_session_builder(builder, &DeviceType::Cuda { device_id: 0 }).unwrap();
        assert_eq!(actual_device, DeviceType::Cpu);
    }

    #[cfg(not(feature = "coreml"))]
    #[test]
    fn test_coreml_fallback_without_feature() {
        let builder = ort::session::Session::builder().expect("session builder");
        let (_, actual_device) = configure_session_builder(builder, &DeviceType::CoreML).unwrap();
        assert_eq!(actual_device, DeviceType::Cpu);
    }

    #[cfg(not(feature = "directml"))]
    #[test]
    fn test_directml_fallback_without_feature() {
        let builder = ort::session::Session::builder().expect("session builder");
        let (_, actual_device) =
            configure_session_builder(builder, &DeviceType::DirectML { device_id: 0 }).unwrap();
        assert_eq!(actual_device, DeviceType::Cpu);
    }

    #[cfg(not(feature = "tensorrt"))]
    #[test]
    fn test_tensorrt_fallback_without_feature() {
        let builder = ort::session::Session::builder().expect("session builder");
        let (_, actual_device) =
            configure_session_builder(builder, &DeviceType::TensorRT { device_id: 0 }).unwrap();
        assert_eq!(actual_device, DeviceType::Cpu);
    }

    // ----- misc -----

    #[test]
    fn test_auto_detect_device_returns_valid() {
        let dt = parse_device_string("auto").unwrap();
        // The exhaustive match is the check: it stops compiling if a new
        // variant is added without this test being revisited.
        match dt {
            DeviceType::Cpu
            | DeviceType::Cuda { .. }
            | DeviceType::CoreML
            | DeviceType::DirectML { .. }
            | DeviceType::TensorRT { .. } => {}
        }
    }

    #[test]
    fn test_parse_device_string_whitespace() {
        let dt = parse_device_string(" cuda ").unwrap();
        assert_eq!(dt, DeviceType::Cuda { device_id: 0 });
    }

    #[test]
    fn test_parse_device_string_large_device_id() {
        let dt = parse_device_string("cuda:999").unwrap();
        assert_eq!(dt, DeviceType::Cuda { device_id: 999 });
    }

    #[test]
    fn test_device_type_default_display_roundtrip() {
        // A fixed-size array is enough here; no heap allocation needed.
        let variants = [
            DeviceType::Cpu,
            DeviceType::Cuda { device_id: 0 },
            DeviceType::Cuda { device_id: 7 },
            DeviceType::CoreML,
            DeviceType::DirectML { device_id: 0 },
            DeviceType::DirectML { device_id: 3 },
            DeviceType::TensorRT { device_id: 0 },
            DeviceType::TensorRT { device_id: 5 },
        ];
        for variant in variants {
            let displayed = variant.to_string();
            let parsed = parse_device_string(&displayed).unwrap();
            assert_eq!(parsed, variant, "roundtrip failed for '{displayed}'");
        }
    }
}