// ultralytics_inference/device.rs
use std::fmt;
use std::str::FromStr;

/// Execution device/backend, written as strings such as `cpu`, `cuda` or `cuda:1`.
///
/// Indexed variants carry a device index and are rendered as `name:N`; the
/// remaining variants are rendered as their bare lowercase name. All variants
/// hold only `usize` data, so the type is cheap to copy and usable as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Device {
    /// CPU execution (`cpu`).
    Cpu,
    /// CUDA device with the given index (`cuda` or `cuda:N`).
    Cuda(usize),
    /// MPS backend (`mps`).
    Mps,
    /// CoreML backend (`coreml`).
    CoreMl,
    /// DirectML device with the given index (`directml` or `directml:N`).
    DirectMl(usize),
    /// OpenVINO backend (`openvino`).
    OpenVino,
    /// XNNPACK backend (`xnnpack`).
    Xnnpack,
    /// TensorRT device with the given index (`tensorrt` or `tensorrt:N`).
    TensorRt(usize),
    /// ROCm device with the given index (`rocm` or `rocm:N`).
    Rocm(usize),
}
33
34impl fmt::Display for Device {
35 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
36 match self {
37 Self::Cpu => write!(f, "cpu"),
38 Self::Cuda(i) => write!(f, "cuda:{i}"),
39 Self::Mps => write!(f, "mps"),
40 Self::CoreMl => write!(f, "coreml"),
41 Self::DirectMl(i) => write!(f, "directml:{i}"),
42 Self::OpenVino => write!(f, "openvino"),
43 Self::Xnnpack => write!(f, "xnnpack"),
44 Self::TensorRt(i) => write!(f, "tensorrt:{i}"),
45 Self::Rocm(i) => write!(f, "rocm:{i}"),
46 }
47 }
48}
49
50impl FromStr for Device {
51 type Err = String;
52
53 fn from_str(s: &str) -> Result<Self, Self::Err> {
54 let s = s.to_lowercase();
55 if let Some(rest) = s.strip_prefix("cuda") {
56 return Ok(Self::Cuda(parse_device_index(rest)));
57 }
58 if let Some(rest) = s.strip_prefix("directml") {
59 return Ok(Self::DirectMl(parse_device_index(rest)));
60 }
61 if let Some(rest) = s.strip_prefix("tensorrt") {
62 return Ok(Self::TensorRt(parse_device_index(rest)));
63 }
64 if let Some(rest) = s.strip_prefix("rocm") {
65 return Ok(Self::Rocm(parse_device_index(rest)));
66 }
67 match s.as_str() {
68 "cpu" => Ok(Self::Cpu),
69 "mps" => Ok(Self::Mps),
70 "coreml" => Ok(Self::CoreMl),
71 "openvino" => Ok(Self::OpenVino),
72 "xnnpack" => Ok(Self::Xnnpack),
73 _ => Err(format!("Unknown device: {s}")),
74 }
75 }
76}
77
78fn parse_device_index(s: &str) -> usize {
80 s.strip_prefix(':')
81 .and_then(|i| i.parse().ok())
82 .unwrap_or(0)
83}
84
#[cfg(test)]
mod tests {
    use super::*;

    /// Every backend name parses, with and without an explicit index.
    #[test]
    fn test_parse_device() {
        assert_eq!(Device::from_str("cpu").unwrap(), Device::Cpu);
        assert_eq!(Device::from_str("cuda").unwrap(), Device::Cuda(0));
        assert_eq!(Device::from_str("cuda:0").unwrap(), Device::Cuda(0));
        assert_eq!(Device::from_str("cuda:1").unwrap(), Device::Cuda(1));
        assert_eq!(Device::from_str("mps").unwrap(), Device::Mps);
        assert_eq!(Device::from_str("coreml").unwrap(), Device::CoreMl);
        assert_eq!(Device::from_str("directml").unwrap(), Device::DirectMl(0));
        assert_eq!(Device::from_str("directml:1").unwrap(), Device::DirectMl(1));
        assert_eq!(Device::from_str("openvino").unwrap(), Device::OpenVino);
        assert_eq!(Device::from_str("xnnpack").unwrap(), Device::Xnnpack);
        assert_eq!(Device::from_str("tensorrt").unwrap(), Device::TensorRt(0));
        assert_eq!(Device::from_str("tensorrt:2").unwrap(), Device::TensorRt(2));
        assert_eq!(Device::from_str("rocm").unwrap(), Device::Rocm(0));
        assert_eq!(Device::from_str("rocm:3").unwrap(), Device::Rocm(3));
    }

    /// Parsing lowercases its input before matching.
    #[test]
    fn test_parse_is_case_insensitive() {
        assert_eq!(Device::from_str("CPU").unwrap(), Device::Cpu);
        assert_eq!(Device::from_str("CUDA:1").unwrap(), Device::Cuda(1));
        assert_eq!(Device::from_str("CoreML").unwrap(), Device::CoreMl);
    }

    /// Unrecognised names are reported with the (lowercased) input.
    #[test]
    fn test_parse_unknown_device() {
        assert_eq!(
            Device::from_str("gpu"),
            Err("Unknown device: gpu".to_string())
        );
        assert!(Device::from_str("cpu:0").is_err());
        assert!(Device::from_str("").is_err());
    }

    /// `Display` output parses back to the same device.
    #[test]
    fn test_display_round_trip() {
        let devices = [
            Device::Cpu,
            Device::Cuda(2),
            Device::Mps,
            Device::CoreMl,
            Device::DirectMl(1),
            Device::OpenVino,
            Device::Xnnpack,
            Device::TensorRt(3),
            Device::Rocm(4),
        ];
        for device in &devices {
            assert_eq!(&Device::from_str(&device.to_string()).unwrap(), device);
        }
        assert_eq!(Device::Cuda(2).to_string(), "cuda:2");
        assert_eq!(Device::Cpu.to_string(), "cpu");
    }
}