// Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
// ultralytics_inference/device.rs

//! Hardware device support and abstraction.
use std::fmt;
use std::str::FromStr;

/// Hardware device for inference.
///
/// Variants that carry a `usize` hold the index of the device to use
/// (e.g., `0` for the first GPU).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Device {
    /// CPU (Central Processing Unit).
    Cpu,
    /// CUDA (Compute Unified Device Architecture) for NVIDIA GPUs.
    /// The argument specifies the device index (e.g., 0 for the first GPU).
    Cuda(usize),
    /// MPS (Metal Performance Shaders) for Apple Silicon (macOS).
    Mps,
    /// `CoreML` (Apple Core Machine Learning).
    CoreMl,
    /// `DirectML` (Direct Machine Learning) for Windows.
    /// The argument specifies the device index.
    DirectMl(usize),
    /// `OpenVINO` (Open Visual Inference and Neural Network Optimization) for Intel hardware.
    OpenVino,
    /// XNNPACK (optimized floating-point neural network inference operators) for CPU.
    Xnnpack,
    /// `TensorRT` (NVIDIA `TensorRT`) for high-performance deep learning inference.
    /// The argument specifies the device index.
    TensorRt(usize),
    /// `ROCm` (Radeon Open Compute) for AMD GPUs.
    /// The argument specifies the device index.
    Rocm(usize),
}

34impl fmt::Display for Device {
35    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
36        match self {
37            Self::Cpu => write!(f, "cpu"),
38            Self::Cuda(i) => write!(f, "cuda:{i}"),
39            Self::Mps => write!(f, "mps"),
40            Self::CoreMl => write!(f, "coreml"),
41            Self::DirectMl(i) => write!(f, "directml:{i}"),
42            Self::OpenVino => write!(f, "openvino"),
43            Self::Xnnpack => write!(f, "xnnpack"),
44            Self::TensorRt(i) => write!(f, "tensorrt:{i}"),
45            Self::Rocm(i) => write!(f, "rocm:{i}"),
46        }
47    }
48}
49
50impl FromStr for Device {
51    type Err = String;
52
53    fn from_str(s: &str) -> Result<Self, Self::Err> {
54        let s = s.to_lowercase();
55        if let Some(rest) = s.strip_prefix("cuda") {
56            return Ok(Self::Cuda(parse_device_index(rest)));
57        }
58        if let Some(rest) = s.strip_prefix("directml") {
59            return Ok(Self::DirectMl(parse_device_index(rest)));
60        }
61        if let Some(rest) = s.strip_prefix("tensorrt") {
62            return Ok(Self::TensorRt(parse_device_index(rest)));
63        }
64        if let Some(rest) = s.strip_prefix("rocm") {
65            return Ok(Self::Rocm(parse_device_index(rest)));
66        }
67        match s.as_str() {
68            "cpu" => Ok(Self::Cpu),
69            "mps" => Ok(Self::Mps),
70            "coreml" => Ok(Self::CoreMl),
71            "openvino" => Ok(Self::OpenVino),
72            "xnnpack" => Ok(Self::Xnnpack),
73            _ => Err(format!("Unknown device: {s}")),
74        }
75    }
76}
77
78/// Parse a trailing device index like `":0"`, defaulting to `0` when absent.
79fn parse_device_index(s: &str) -> usize {
80    s.strip_prefix(':')
81        .and_then(|i| i.parse().ok())
82        .unwrap_or(0)
83}
84
#[cfg(test)]
mod tests {
    use super::*;

    /// Every supported spelling parses to the expected variant, and indexed
    /// backends default to index 0 when no index is given.
    #[test]
    fn test_parse_device() {
        assert_eq!(Device::from_str("cpu").unwrap(), Device::Cpu);
        assert_eq!(Device::from_str("cuda").unwrap(), Device::Cuda(0));
        assert_eq!(Device::from_str("cuda:0").unwrap(), Device::Cuda(0));
        assert_eq!(Device::from_str("cuda:1").unwrap(), Device::Cuda(1));
        assert_eq!(Device::from_str("mps").unwrap(), Device::Mps);
        assert_eq!(Device::from_str("coreml").unwrap(), Device::CoreMl);
        assert_eq!(Device::from_str("directml").unwrap(), Device::DirectMl(0));
        assert_eq!(Device::from_str("directml:1").unwrap(), Device::DirectMl(1));
        assert_eq!(Device::from_str("openvino").unwrap(), Device::OpenVino);
        assert_eq!(Device::from_str("xnnpack").unwrap(), Device::Xnnpack);
        assert_eq!(Device::from_str("tensorrt").unwrap(), Device::TensorRt(0));
        assert_eq!(Device::from_str("tensorrt:2").unwrap(), Device::TensorRt(2));
        assert_eq!(Device::from_str("rocm").unwrap(), Device::Rocm(0));
        assert_eq!(Device::from_str("rocm:3").unwrap(), Device::Rocm(3));
    }

    /// Device names are matched without regard to letter case.
    #[test]
    fn test_parse_device_case_insensitive() {
        assert_eq!(Device::from_str("CPU").unwrap(), Device::Cpu);
        assert_eq!(Device::from_str("CUDA:1").unwrap(), Device::Cuda(1));
        assert_eq!(Device::from_str("CoreML").unwrap(), Device::CoreMl);
    }

    /// Names that match no backend are reported as errors.
    #[test]
    fn test_parse_unknown_device() {
        assert!(Device::from_str("tpu").is_err());
        assert!(Device::from_str("").is_err());
    }

    /// The `Display` output of every variant parses back to the same value.
    #[test]
    fn test_display_round_trip() {
        let devices = [
            Device::Cpu,
            Device::Cuda(0),
            Device::Cuda(2),
            Device::Mps,
            Device::CoreMl,
            Device::DirectMl(1),
            Device::OpenVino,
            Device::Xnnpack,
            Device::TensorRt(1),
            Device::Rocm(4),
        ];
        for device in devices {
            let parsed: Device = device.to_string().parse().unwrap();
            assert_eq!(parsed, device);
        }
    }
}