1use serde::{Deserialize, Serialize};
7use std::fmt;
8
9#[allow(unused_imports)]
10#[cfg(all(feature = "onnx", not(target_arch = "wasm32")))]
11use ort::execution_providers::ExecutionProvider;
12
13#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
15pub enum Device {
16 Cpu(usize),
18 Cuda(usize),
20 #[cfg(target_arch = "wasm32")]
22 WebGpu,
23}
24
25impl Default for Device {
26 fn default() -> Self {
27 #[cfg(target_arch = "wasm32")]
28 {
29 Self::WebGpu
30 }
31 #[cfg(not(target_arch = "wasm32"))]
32 {
33 Self::Cpu(0)
34 }
35 }
36}
37
38impl fmt::Display for Device {
39 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
40 match self {
41 Self::Cpu(i) => write!(f, "cpu:{i}"),
42 Self::Cuda(i) => write!(f, "cuda:{i}"),
43 #[cfg(target_arch = "wasm32")]
44 Self::WebGpu => write!(f, "webgpu"),
45 }
46 }
47}
48
49impl std::str::FromStr for Device {
50 type Err = crate::EdgeError;
51
52 fn from_str(s: &str) -> Result<Self, Self::Err> {
53 #[inline]
54 fn parse_device_id(id_str: Option<&str>) -> usize {
55 id_str
56 .map(|s| s.trim().parse::<usize>().unwrap_or(0))
57 .unwrap_or(0)
58 }
59
60 let (device_type, id_part) = s
61 .trim()
62 .split_once(':')
63 .map_or_else(|| (s.trim(), None), |(device, id)| (device, Some(id)));
64
65 match device_type.to_lowercase().as_str() {
66 "cpu" => Ok(Self::Cpu(parse_device_id(id_part))),
67 "cuda" => Ok(Self::Cuda(parse_device_id(id_part))),
68 #[cfg(target_arch = "wasm32")]
69 "webgpu" => Ok(Self::WebGpu),
70 _ => Err(crate::EdgeError::runtime(format!(
71 "Unsupported device: {s}"
72 ))),
73 }
74 }
75}
76
77impl Device {
78 pub fn id(&self) -> Option<usize> {
80 match self {
81 Self::Cpu(i) | Self::Cuda(i) => Some(*i),
82 #[cfg(target_arch = "wasm32")]
83 Self::WebGpu => None,
84 }
85 }
86
87 pub fn is_available(&self) -> bool {
89 match self {
90 Self::Cpu(_) => true, Self::Cuda(_) => {
92 #[cfg(all(feature = "onnx", feature = "cuda"))]
93 {
94 use ort::execution_providers::CUDAExecutionProvider;
95 CUDAExecutionProvider::default()
96 .with_device_id(self.id().unwrap_or(0) as i32)
97 .is_available()
98 .unwrap_or(false)
99 }
100 #[cfg(not(all(feature = "onnx", feature = "cuda")))]
101 {
102 false
103 }
104 }
105 #[cfg(target_arch = "wasm32")]
106 Self::WebGpu => {
107 cfg!(feature = "onnx")
110 }
111 }
112 }
113}
114
115pub fn cpu() -> Device {
117 Device::Cpu(1)
118}
119
120pub fn cpu_with_threads(threads: usize) -> Device {
121 Device::Cpu(threads)
122}
123
124pub fn cuda(device_id: usize) -> Device {
125 Device::Cuda(device_id)
126}
127
128pub fn cuda_default() -> Device {
129 Device::Cuda(0)
130}
131
/// Builds the WebGPU device; this helper exists only on `wasm32` targets.
#[cfg(target_arch = "wasm32")]
pub fn webgpu() -> Device {
    Device::WebGpu
}