1use serde::{Deserialize, Serialize};
7use std::fmt;
8
9#[allow(unused_imports)]
10#[cfg(feature = "onnx")]
11use ort::execution_providers::ExecutionProvider;
12
13#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
15pub enum Device {
16 Cpu(usize),
18 Cuda(usize),
20}
21
22impl Default for Device {
23 fn default() -> Self {
24 Self::Cpu(0)
25 }
26}
27
28impl fmt::Display for Device {
29 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
30 match self {
31 Self::Cpu(i) => write!(f, "cpu:{i}"),
32 Self::Cuda(i) => write!(f, "cuda:{i}"),
33 }
34 }
35}
36
37impl std::str::FromStr for Device {
38 type Err = crate::EdgeError;
39
40 fn from_str(s: &str) -> Result<Self, Self::Err> {
41 #[inline]
42 fn parse_device_id(id_str: Option<&str>) -> usize {
43 id_str
44 .map(|s| s.trim().parse::<usize>().unwrap_or(0))
45 .unwrap_or(0)
46 }
47
48 let (device_type, id_part) = s
49 .trim()
50 .split_once(':')
51 .map_or_else(|| (s.trim(), None), |(device, id)| (device, Some(id)));
52
53 match device_type.to_lowercase().as_str() {
54 "cpu" => Ok(Self::Cpu(parse_device_id(id_part))),
55 "cuda" => Ok(Self::Cuda(parse_device_id(id_part))),
56 _ => Err(crate::EdgeError::runtime(format!(
57 "Unsupported device: {s}"
58 ))),
59 }
60 }
61}
62
63impl Device {
64 pub fn id(&self) -> Option<usize> {
66 match self {
67 Self::Cpu(i) | Self::Cuda(i) => Some(*i),
68 }
69 }
70
71 pub fn is_available(&self) -> bool {
73 match self {
74 Self::Cpu(_) => true, Self::Cuda(_) => {
76 #[cfg(all(feature = "onnx", feature = "cuda"))]
77 {
78 use ort::execution_providers::CUDAExecutionProvider;
79 CUDAExecutionProvider::default()
80 .with_device_id(self.id().unwrap_or(0) as i32)
81 .is_available()
82 .unwrap_or(false)
83 }
84 #[cfg(not(all(feature = "onnx", feature = "cuda")))]
85 {
86 false
87 }
88 }
89 }
90 }
91}
92
93pub fn cpu() -> Device {
95 Device::Cpu(1)
96}
97
98pub fn cpu_with_threads(threads: usize) -> Device {
99 Device::Cpu(threads)
100}
101
102pub fn cuda(device_id: usize) -> Device {
103 Device::Cuda(device_id)
104}
105
106pub fn cuda_default() -> Device {
107 Device::Cuda(0)
108}