autoagents_onnx/
device.rs1use serde::{Deserialize, Serialize};
7use std::fmt;
8
9#[allow(unused_imports)]
10use ort::execution_providers::ExecutionProvider;
11
12#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
14pub enum Device {
15 Cpu(usize),
17 #[cfg(feature = "cuda")]
19 Cuda(usize),
20}
21
22impl Default for Device {
23 fn default() -> Self {
24 {
25 Self::Cpu(0)
26 }
27 }
28}
29
30impl fmt::Display for Device {
31 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
32 match self {
33 Self::Cpu(i) => write!(f, "cpu:{i}"),
34 #[cfg(feature = "cuda")]
35 Self::Cuda(i) => write!(f, "cuda:{i}"),
36 }
37 }
38}
39
40impl std::str::FromStr for Device {
41 type Err = crate::EdgeError;
42
43 fn from_str(s: &str) -> Result<Self, Self::Err> {
44 #[inline]
45 fn parse_device_id(id_str: Option<&str>) -> usize {
46 id_str
47 .map(|s| s.trim().parse::<usize>().unwrap_or(0))
48 .unwrap_or(0)
49 }
50
51 let (device_type, id_part) = s
52 .trim()
53 .split_once(':')
54 .map_or_else(|| (s.trim(), None), |(device, id)| (device, Some(id)));
55
56 match device_type.to_lowercase().as_str() {
57 "cpu" => Ok(Self::Cpu(parse_device_id(id_part))),
58 #[cfg(feature = "cuda")]
59 "cuda" => Ok(Self::Cuda(parse_device_id(id_part))),
60 _ => Err(crate::EdgeError::runtime(format!(
61 "Unsupported device: {s}"
62 ))),
63 }
64 }
65}
66
67impl Device {
68 pub fn id(&self) -> Option<usize> {
70 match self {
71 Self::Cpu(i) => Some(*i),
72 #[cfg(feature = "cuda")]
73 Self::Cuda(i) => Some(*i),
74 }
75 }
76
77 pub fn is_available(&self) -> bool {
79 match self {
80 Self::Cpu(_) => true, #[cfg(feature = "cuda")]
82 Self::Cuda(_) => {
83 use ort::execution_providers::CUDAExecutionProvider;
84 CUDAExecutionProvider::default()
85 .with_device_id(self.id().unwrap_or(0) as i32)
86 .is_available()
87 .unwrap_or(false)
88 }
89 }
90 }
91}
92
93pub fn cpu() -> Device {
95 Device::Cpu(1)
96}
97
98pub fn cpu_with_threads(threads: usize) -> Device {
99 Device::Cpu(threads)
100}
101
/// Returns a CUDA device targeting the given device ordinal.
#[cfg(feature = "cuda")]
pub fn cuda(device_id: usize) -> Device {
    Device::Cuda(device_id)
}

/// Returns a CUDA device targeting ordinal `0`; shorthand for `cuda(0)`.
#[cfg(feature = "cuda")]
pub fn cuda_default() -> Device {
    cuda(0)
}